diff options
650 files changed, 25346 insertions, 5387 deletions
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java deleted file mode 100644 index 5a509d77431..00000000000 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.athenz.instanceproviderservice; - -import com.google.inject.Inject; -import com.yahoo.cloud.config.ConfigserverConfig; -import com.yahoo.component.AbstractComponent; -import com.yahoo.config.provision.Zone; -import com.yahoo.jdisc.http.ssl.SslKeyStoreConfigurator; -import com.yahoo.jdisc.http.ssl.SslKeyStoreContext; -import com.yahoo.log.LogLevel; -import com.yahoo.vespa.athenz.api.AthenzService; -import com.yahoo.vespa.athenz.client.zts.DefaultZtsClient; -import com.yahoo.vespa.athenz.client.zts.Identity; -import com.yahoo.vespa.athenz.client.zts.ZtsClient; -import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider; -import com.yahoo.vespa.athenz.tls.KeyStoreBuilder; -import com.yahoo.vespa.athenz.tls.KeyStoreType; -import com.yahoo.vespa.athenz.tls.KeyUtils; -import com.yahoo.vespa.athenz.utils.SiaUtils; -import com.yahoo.vespa.defaults.Defaults; -import com.yahoo.vespa.hosted.athenz.instanceproviderservice.config.AthenzProviderServiceConfig; - -import java.net.URI; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.security.KeyPair; -import java.security.KeyStore; -import java.security.KeyStoreException; -import java.security.PrivateKey; -import java.security.PublicKey; -import java.security.cert.X509Certificate; -import java.time.Duration; -import java.time.Instant; -import java.util.List; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.Executors; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import java.util.logging.Logger; - -import static com.yahoo.vespa.hosted.athenz.instanceproviderservice.impl.Utils.getZoneConfig; - -/** - * A component that is responsible for retrieving an Athenz TLS certificate and configuring the configserver to use - * that certificate for its HTTPS endpoint. - * - * @author bjorncs - */ -@SuppressWarnings("unused") // Component injected into Jetty connector factory -public class AthenzSslKeyStoreConfigurator extends AbstractComponent implements SslKeyStoreConfigurator { - private static final Logger log = Logger.getLogger(AthenzSslKeyStoreConfigurator.class.getName()); - private static final String CERTIFICATE_ALIAS = "athenz"; - private static final Duration EXPIRATION_MARGIN = Duration.ofHours(6); - private static final Path VESPA_SIA_DIRECTORY = Paths.get(Defaults.getDefaults().underVespaHome("var/vespa/sia")); - - private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor(); - private final ZtsClient ztsClient; - private final KeyProvider keyProvider; - private final AthenzProviderServiceConfig.Zones zoneConfig; - private final Duration updatePeriod; - private final AthenzService configserverIdentity; - private volatile KeyStoreAndPassword currentKeyStore; - - @Inject - public AthenzSslKeyStoreConfigurator(ServiceIdentityProvider bootstrapIdentity, - KeyProvider keyProvider, - AthenzProviderServiceConfig config, - Zone zone, - ConfigserverConfig configserverConfig) { - AthenzProviderServiceConfig.Zones zoneConfig = getZoneConfig(config, zone); - AthenzService configserverIdentity = new AthenzService(zoneConfig.domain(), zoneConfig.serviceName()); - Duration updatePeriod = Duration.ofDays(config.updatePeriodDays()); - DefaultZtsClient ztsClient = new DefaultZtsClient(URI.create(zoneConfig.ztsUrl()), bootstrapIdentity); - this.ztsClient = ztsClient; - this.keyProvider = keyProvider; - this.zoneConfig = zoneConfig; - this.currentKeyStore = initializeKeystore(configserverIdentity, keyProvider, ztsClient, zoneConfig, updatePeriod); - this.updatePeriod = updatePeriod; - this.configserverIdentity = configserverIdentity; - } - - private static KeyStoreAndPassword initializeKeystore(AthenzService configserverIdentity, - KeyProvider keyProvider, - ZtsClient ztsClient, - AthenzProviderServiceConfig.Zones keystoreCacheDirectory, - Duration updatePeriod) { - return tryReadKeystoreFile(configserverIdentity, updatePeriod) - .orElseGet(() -> downloadCertificate(configserverIdentity, keyProvider, ztsClient, keystoreCacheDirectory)); - } - - private static Optional<KeyStoreAndPassword> tryReadKeystoreFile(AthenzService configserverIdentity, - Duration updatePeriod) { - Optional<X509Certificate> certificate = SiaUtils.readCertificateFile(VESPA_SIA_DIRECTORY, configserverIdentity); - if (!certificate.isPresent()) return Optional.empty(); - Optional<PrivateKey> privateKey = SiaUtils.readPrivateKeyFile(VESPA_SIA_DIRECTORY, configserverIdentity); - if (!privateKey.isPresent()) return Optional.empty(); - Instant minimumExpiration = Instant.now().plus(updatePeriod).plus(EXPIRATION_MARGIN); - boolean isExpired = certificate.get().getNotAfter().toInstant().isBefore(minimumExpiration); - if (isExpired) return Optional.empty(); - char[] password = generateKeystorePassword(); - KeyStore keyStore = KeyStoreBuilder.withType(KeyStoreType.JKS) - .withKeyEntry(CERTIFICATE_ALIAS, privateKey.get(), password, certificate.get()) - .build(); - return Optional.of(new KeyStoreAndPassword(keyStore, password)); - } - - @Override - public void configure(SslKeyStoreContext sslKeyStoreContext) { - sslKeyStoreContext.updateKeyStore(currentKeyStore.keyStore, new String(currentKeyStore.password)); - scheduler.scheduleAtFixedRate(new AthenzCertificateUpdater(sslKeyStoreContext), - updatePeriod.toDays()/*initial delay*/, - updatePeriod.toDays(), - TimeUnit.DAYS); - } - - @Override - public void deconstruct() { - try { - scheduler.shutdownNow(); - scheduler.awaitTermination(30, TimeUnit.SECONDS); - ztsClient.close(); - } catch (InterruptedException e) { - throw new RuntimeException("Failed to shutdown Athenz certificate updater on time", e); - } - } - - Instant getCertificateExpiry() throws KeyStoreException { - return getCertificateExpiry(currentKeyStore.keyStore); - } - - private static Instant getCertificateExpiry(KeyStore keyStore) throws KeyStoreException { - X509Certificate certificate = (X509Certificate) keyStore.getCertificate(CERTIFICATE_ALIAS); - return certificate.getNotAfter().toInstant(); - } - - private static KeyStoreAndPassword downloadCertificate(AthenzService configserverIdentity, - KeyProvider keyProvider, - ZtsClient ztsClient, - AthenzProviderServiceConfig.Zones zoneConfig) { - PrivateKey privateKey = keyProvider.getPrivateKey(zoneConfig.secretVersion()); - PublicKey publicKey = KeyUtils.extractPublicKey(privateKey); - Identity serviceIdentity = ztsClient.getServiceIdentity(configserverIdentity, - Integer.toString(zoneConfig.secretVersion()), - new KeyPair(publicKey, privateKey), - zoneConfig.certDnsSuffix()); - X509Certificate certificate = serviceIdentity.certificate(); - writeCredentials(configserverIdentity, certificate, serviceIdentity.caCertificates(), privateKey); - Instant expirationTime = certificate.getNotAfter().toInstant(); - Duration expiry = Duration.between(certificate.getNotBefore().toInstant(), expirationTime); - log.log(LogLevel.INFO, String.format("Got Athenz x509 certificate with expiry %s (expires %s)", expiry, expirationTime)); - - char[] keystorePassword = generateKeystorePassword(); - KeyStore keyStore = KeyStoreBuilder.withType(KeyStoreType.JKS) - .withKeyEntry(CERTIFICATE_ALIAS, privateKey, keystorePassword, certificate) - .build(); - return new KeyStoreAndPassword(keyStore, keystorePassword); - } - - private static void writeCredentials(AthenzService configserverIdentity, - X509Certificate certificate, - List<X509Certificate> caCertificates, - PrivateKey privateKey) { - SiaUtils.writeCertificateFile(VESPA_SIA_DIRECTORY, configserverIdentity, certificate); - SiaUtils.writePrivateKeyFile(VESPA_SIA_DIRECTORY, configserverIdentity, privateKey); - } - - private static char[] generateKeystorePassword() { - return UUID.randomUUID().toString().toCharArray(); - } - - private class AthenzCertificateUpdater implements Runnable { - - private final SslKeyStoreContext sslKeyStoreContext; - - AthenzCertificateUpdater(SslKeyStoreContext sslKeyStoreContext) { - this.sslKeyStoreContext = sslKeyStoreContext; - } - - @Override - public void run() { - try { - log.log(LogLevel.INFO, "Updating Athenz certificate from ZTS"); - currentKeyStore = downloadCertificate(configserverIdentity, keyProvider, ztsClient, zoneConfig); - sslKeyStoreContext.updateKeyStore(currentKeyStore.keyStore, new String(currentKeyStore.password)); - log.log(LogLevel.INFO, "Athenz certificate reload successfully completed"); - } catch (Throwable e) { - log.log(LogLevel.ERROR, "Failed to update certificate from ZTS: " + e.getMessage(), e); - } - } - - } - - private static class KeyStoreAndPassword { - final KeyStore keyStore; - final char[] password; - - KeyStoreAndPassword(KeyStore keyStore, char[] password) { - this.keyStore = keyStore; - this.password = password; - } - } -} diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslTrustStoreConfigurator.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslTrustStoreConfigurator.java deleted file mode 100644 index a440f96cc49..00000000000 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslTrustStoreConfigurator.java +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.athenz.instanceproviderservice; - -import com.google.inject.Inject; -import com.yahoo.jdisc.http.ssl.SslTrustStoreConfigurator; -import com.yahoo.jdisc.http.ssl.SslTrustStoreContext; -import com.yahoo.vespa.athenz.tls.KeyStoreBuilder; -import com.yahoo.vespa.athenz.tls.KeyStoreType; -import com.yahoo.vespa.hosted.athenz.instanceproviderservice.config.AthenzProviderServiceConfig; - -import java.io.File; -import java.security.KeyStore; -import java.security.KeyStoreException; -import java.security.cert.X509Certificate; -import java.time.Instant; - -/** - * Programmatic configuration of configserver's truststore - * - * @author bjorncs - */ -public class AthenzSslTrustStoreConfigurator implements SslTrustStoreConfigurator { - - private static final String CERTIFICATE_ALIAS = "cfgselfsigned"; - - private final KeyStore trustStore; - - @Inject - public AthenzSslTrustStoreConfigurator(AthenzProviderServiceConfig athenzProviderServiceConfig) { - this.trustStore = createTrustStore(athenzProviderServiceConfig); - } - - @Override - public void configure(SslTrustStoreContext sslTrustStoreContext) { - sslTrustStoreContext.updateTrustStore(trustStore); - } - - Instant getTrustStoreExpiry() throws KeyStoreException { - X509Certificate certificate = (X509Certificate) trustStore.getCertificate(CERTIFICATE_ALIAS); - return certificate.getNotAfter().toInstant(); - } - - private static KeyStore createTrustStore(AthenzProviderServiceConfig athenzProviderServiceConfig) { - try { - return KeyStoreBuilder.withType(KeyStoreType.JKS) - .fromFile(new File(athenzProviderServiceConfig.athenzCaTrustStore())) - .build(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - -} diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/CertificateExpiryMetricUpdater.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/CertificateExpiryMetricUpdater.java index 2d80b15c7ec..cd69099ea80 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/CertificateExpiryMetricUpdater.java +++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/CertificateExpiryMetricUpdater.java @@ -1,12 +1,10 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.athenz.instanceproviderservice; +import com.google.inject.Inject; import com.yahoo.component.AbstractComponent; import com.yahoo.jdisc.Metric; -import com.google.inject.Inject; - -import java.security.KeyStoreException; import java.time.Duration; import java.time.Instant; import java.util.concurrent.Executors; @@ -21,23 +19,18 @@ import java.util.logging.Logger; public class CertificateExpiryMetricUpdater extends AbstractComponent { private static final Duration METRIC_REFRESH_PERIOD = Duration.ofMinutes(5); - private static final String NODE_CA_CERT_METRIC_NAME = "node-ca-cert.expiry.seconds"; private static final String ATHENZ_CONFIGSERVER_CERT_METRIC_NAME = "athenz-configserver-cert.expiry.seconds"; private final Logger logger = Logger.getLogger(CertificateExpiryMetricUpdater.class.getName()); private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor(); private final Metric metric; - private final AthenzSslKeyStoreConfigurator keyStoreConfigurator; - private final AthenzSslTrustStoreConfigurator trustStoreConfigurator; + private final ConfigserverSslContextFactoryProvider provider; @Inject public CertificateExpiryMetricUpdater(Metric metric, - AthenzSslKeyStoreConfigurator keyStoreConfigurator, - AthenzSslTrustStoreConfigurator trustStoreConfigurator) { + ConfigserverSslContextFactoryProvider provider) { this.metric = metric; - this.keyStoreConfigurator = keyStoreConfigurator; - this.trustStoreConfigurator = trustStoreConfigurator; - + this.provider = provider; scheduler.scheduleAtFixedRate(this::updateMetrics, 30/*initial delay*/, @@ -56,20 +49,11 @@ public class CertificateExpiryMetricUpdater extends AbstractComponent { } private void updateMetrics() { - Instant now = Instant.now(); - try { - Duration keyStoreExpiry = Duration.between(now, keyStoreConfigurator.getCertificateExpiry()); + Duration keyStoreExpiry = Duration.between(Instant.now(), provider.getCertificateNotAfter()); metric.set(ATHENZ_CONFIGSERVER_CERT_METRIC_NAME, keyStoreExpiry.getSeconds(), null); - } catch (KeyStoreException e) { - logger.log(Level.WARNING, "Failed to update key store expiry metric", e); - } - - try { - Duration trustStoreExpiry = Duration.between(now, trustStoreConfigurator.getTrustStoreExpiry()); - metric.set(NODE_CA_CERT_METRIC_NAME, trustStoreExpiry.getSeconds(), null); - } catch (KeyStoreException e) { - logger.log(Level.WARNING, "Failed to update trust store expiry metric", e); + } catch (Exception e) { + logger.log(Level.WARNING, "Failed to update key store expiry metric: " + e.getMessage(), e); } } } diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/ConfigserverSslContextFactoryProvider.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/ConfigserverSslContextFactoryProvider.java new file mode 100644 index 00000000000..94df93aaea7 --- /dev/null +++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/ConfigserverSslContextFactoryProvider.java @@ -0,0 +1,200 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.athenz.instanceproviderservice; + +import com.google.inject.Inject; +import com.yahoo.component.AbstractComponent; +import com.yahoo.config.provision.Zone; +import com.yahoo.jdisc.http.ssl.SslContextFactoryProvider; +import com.yahoo.log.LogLevel; +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.KeyUtils; +import com.yahoo.vespa.athenz.api.AthenzService; +import com.yahoo.vespa.athenz.client.zts.DefaultZtsClient; +import com.yahoo.vespa.athenz.client.zts.Identity; +import com.yahoo.vespa.athenz.client.zts.ZtsClient; +import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider; +import com.yahoo.vespa.athenz.utils.SiaUtils; +import com.yahoo.vespa.defaults.Defaults; +import com.yahoo.vespa.hosted.athenz.instanceproviderservice.config.AthenzProviderServiceConfig; +import org.eclipse.jetty.util.ssl.SslContextFactory; + +import java.net.URI; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.GeneralSecurityException; +import java.security.KeyPair; +import java.security.KeyStore; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.cert.X509Certificate; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; + +import static com.yahoo.vespa.hosted.athenz.instanceproviderservice.impl.Utils.getZoneConfig; + +/** + * Configures the JDisc https connector with the configserver's Athenz provider certificate and private key. + * + * @author bjorncs + */ +public class ConfigserverSslContextFactoryProvider extends AbstractComponent implements SslContextFactoryProvider { + private static final String CERTIFICATE_ALIAS = "athenz"; + private static final Duration EXPIRATION_MARGIN = Duration.ofHours(6); + private static final Path VESPA_SIA_DIRECTORY = Paths.get(Defaults.getDefaults().underVespaHome("var/vespa/sia")); + + private static final Logger log = Logger.getLogger(ConfigserverSslContextFactoryProvider.class.getName()); + + private final SslContextFactory sslContextFactory; + private final ScheduledExecutorService scheduler = + Executors.newSingleThreadScheduledExecutor(runnable -> new Thread(runnable, "configserver-ssl-context-factory-provider")); + private final ZtsClient ztsClient; + private final KeyProvider keyProvider; + private final AthenzProviderServiceConfig.Zones zoneConfig; + private final AthenzService configserverIdentity; + + @Inject + public ConfigserverSslContextFactoryProvider(ServiceIdentityProvider bootstrapIdentity, + KeyProvider keyProvider, + AthenzProviderServiceConfig config, + Zone zone) { + this.zoneConfig = getZoneConfig(config, zone); + this.ztsClient = new DefaultZtsClient(URI.create(zoneConfig.ztsUrl()), bootstrapIdentity); + this.keyProvider = keyProvider; + this.configserverIdentity = new AthenzService(zoneConfig.domain(), zoneConfig.serviceName()); + + Duration updatePeriod = Duration.ofDays(config.updatePeriodDays()); + Path trustStoreFile = Paths.get(config.athenzCaTrustStore()); + this.sslContextFactory = initializeSslContextFactory(keyProvider, trustStoreFile, updatePeriod, configserverIdentity, ztsClient, zoneConfig); + scheduler.scheduleAtFixedRate(new KeystoreUpdater(sslContextFactory), + updatePeriod.toDays()/*initial delay*/, + updatePeriod.toDays(), + TimeUnit.DAYS); + } + + @Override + public SslContextFactory getInstance(String containerId, int port) { + return sslContextFactory; + } + + Instant getCertificateNotAfter() { + try { + X509Certificate certificate = (X509Certificate) sslContextFactory.getKeyStore().getCertificate(CERTIFICATE_ALIAS); + return certificate.getNotAfter().toInstant(); + } catch (GeneralSecurityException e) { + throw new IllegalStateException("Unable to find configserver certificate from keystore: " + e.getMessage(), e); + } + } + + @Override + public void deconstruct() { + try { + scheduler.shutdownNow(); + scheduler.awaitTermination(30, TimeUnit.SECONDS); + ztsClient.close(); + } catch (InterruptedException e) { + throw new RuntimeException("Failed to shutdown Athenz certificate updater on time", e); + } + } + + private static SslContextFactory initializeSslContextFactory(KeyProvider keyProvider, + Path trustStoreFile, + Duration updatePeriod, + AthenzService configserverIdentity, + ZtsClient ztsClient, + AthenzProviderServiceConfig.Zones zoneConfig) { + SslContextFactory factory = new SslContextFactory(); + + // Allow safe TLS_RSA* ciphers + String[] excludedCiphersWithoutTlsRsaExclusion = Arrays.stream(factory.getExcludeCipherSuites()) + .filter(cipher -> !cipher.equals("^TLS_RSA_.*$")) + .toArray(String[]::new); + factory.setExcludeCipherSuites(excludedCiphersWithoutTlsRsaExclusion); + + factory.setWantClientAuth(true); + + KeyStore trustStore = + KeyStoreBuilder.withType(KeyStoreType.JKS) + .fromFile(trustStoreFile) + .build(); + factory.setTrustStore(trustStore); + + KeyStore keyStore = + tryReadKeystoreFile(configserverIdentity, updatePeriod) + .orElseGet(() -> updateKeystore(configserverIdentity, generateKeystorePassword(), keyProvider, ztsClient, zoneConfig)); + factory.setKeyStore(keyStore); + factory.setKeyStorePassword(""); + return factory; + } + + private static Optional<KeyStore> tryReadKeystoreFile(AthenzService configserverIdentity, Duration updatePeriod) { + Optional<X509Certificate> certificate = SiaUtils.readCertificateFile(VESPA_SIA_DIRECTORY, configserverIdentity); + if (!certificate.isPresent()) return Optional.empty(); + Optional<PrivateKey> privateKey = SiaUtils.readPrivateKeyFile(VESPA_SIA_DIRECTORY, configserverIdentity); + if (!privateKey.isPresent()) return Optional.empty(); + Instant minimumExpiration = Instant.now().plus(updatePeriod).plus(EXPIRATION_MARGIN); + boolean isExpired = certificate.get().getNotAfter().toInstant().isBefore(minimumExpiration); + if (isExpired) return Optional.empty(); + KeyStore keyStore = KeyStoreBuilder.withType(KeyStoreType.JKS) + .withKeyEntry(CERTIFICATE_ALIAS, privateKey.get(), certificate.get()) + .build(); + return Optional.of(keyStore); + } + + private static KeyStore updateKeystore(AthenzService configserverIdentity, + char[] keystorePwd, + KeyProvider keyProvider, + ZtsClient ztsClient, + AthenzProviderServiceConfig.Zones zoneConfig) { + PrivateKey privateKey = keyProvider.getPrivateKey(zoneConfig.secretVersion()); + PublicKey publicKey = KeyUtils.extractPublicKey(privateKey); + Identity serviceIdentity = ztsClient.getServiceIdentity(configserverIdentity, + Integer.toString(zoneConfig.secretVersion()), + new KeyPair(publicKey, privateKey), + zoneConfig.certDnsSuffix()); + X509Certificate certificate = serviceIdentity.certificate(); + SiaUtils.writeCertificateFile(VESPA_SIA_DIRECTORY, configserverIdentity, certificate); + SiaUtils.writePrivateKeyFile(VESPA_SIA_DIRECTORY, configserverIdentity, privateKey); + Instant expirationTime = certificate.getNotAfter().toInstant(); + Duration expiry = Duration.between(certificate.getNotBefore().toInstant(), expirationTime); + log.log(LogLevel.INFO, String.format("Got Athenz x509 certificate with expiry %s (expires %s)", expiry, expirationTime)); + return KeyStoreBuilder.withType(KeyStoreType.JKS) + .withKeyEntry(CERTIFICATE_ALIAS, privateKey, keystorePwd, certificate) + .build(); + } + + private static char[] generateKeystorePassword() { + return UUID.randomUUID().toString().toCharArray(); + } + + private class KeystoreUpdater implements Runnable { + final SslContextFactory sslContextFactory; + + KeystoreUpdater(SslContextFactory sslContextFactory) { + this.sslContextFactory = sslContextFactory; + } + + @Override + public void run() { + try { + log.log(LogLevel.INFO, "Updating configserver provider certificate from ZTS"); + char[] keystorePwd = generateKeystorePassword(); + KeyStore keyStore = updateKeystore(configserverIdentity, keystorePwd, keyProvider, ztsClient, zoneConfig); + sslContextFactory.reload(scf -> { + scf.setKeyStore(keyStore); + scf.setKeyStorePassword(new String(keystorePwd)); + }); + log.log(LogLevel.INFO, "Certificate successfully updated"); + } catch (Throwable t) { + log.log(LogLevel.ERROR, "Failed to update certificate from ZTS: " + t.getMessage(), t); + } + } + } +} diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/CkmsKeyProvider.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/CkmsKeyProvider.java index 183a52f782c..40003d4ccf3 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/CkmsKeyProvider.java +++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/CkmsKeyProvider.java @@ -4,7 +4,7 @@ package com.yahoo.vespa.hosted.athenz.instanceproviderservice.impl; import com.google.inject.Inject; import com.yahoo.config.provision.Zone; import com.yahoo.container.jdisc.secretstore.SecretStore; -import com.yahoo.vespa.athenz.tls.KeyUtils; +import com.yahoo.security.KeyUtils; import com.yahoo.vespa.hosted.athenz.instanceproviderservice.KeyProvider; import com.yahoo.vespa.hosted.athenz.instanceproviderservice.config.AthenzProviderServiceConfig; diff --git a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AutoGeneratedKeyProvider.java b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AutoGeneratedKeyProvider.java index ca6b5529b08..74e9b02e150 100644 --- a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AutoGeneratedKeyProvider.java +++ b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AutoGeneratedKeyProvider.java @@ -1,6 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.athenz.instanceproviderservice; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; + import java.security.KeyPair; import java.security.KeyPairGenerator; import java.security.NoSuchAlgorithmException; @@ -15,13 +18,7 @@ public class AutoGeneratedKeyProvider implements KeyProvider { private final KeyPair keyPair; public AutoGeneratedKeyProvider() { - try { - KeyPairGenerator rsa = KeyPairGenerator.getInstance("RSA"); - rsa.initialize(2048); - keyPair = rsa.genKeyPair(); - } catch (NoSuchAlgorithmException e) { - throw new RuntimeException(e); - } + keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA, 2048); } @Override diff --git a/bundle-plugin-test/src/test/java/com/yahoo/BundleIT.java b/bundle-plugin-test/src/test/java/com/yahoo/BundleIT.java index ccee4844b49..38ca08ecff1 100644 --- a/bundle-plugin-test/src/test/java/com/yahoo/BundleIT.java +++ b/bundle-plugin-test/src/test/java/com/yahoo/BundleIT.java @@ -2,6 +2,7 @@ package com.yahoo; import com.yahoo.osgi.maven.ProjectBundleClassPaths; +import com.yahoo.vespa.config.VespaVersion; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; @@ -64,9 +65,12 @@ public class BundleIT { } @Test - @Ignore // TODO Vespa 7: Should we fix this? Why not? - public void require_that_bundle_version_matches_pom_version() { - assertThat(mainAttributes.getValue("Bundle-Version"), is("5.1.0")); + public void require_that_bundle_version_is_added_to_manifest() { + String bundleVersion = mainAttributes.getValue("Bundle-Version"); + + // Because of snapshot builds, we can only verify the major version. + int majorBundleVersion = Integer.valueOf(bundleVersion.substring(0, bundleVersion.indexOf('.'))); + assertThat(majorBundleVersion, is(VespaVersion.major)); } @Test diff --git a/config-lib/src/main/java/com/yahoo/config/PathNode.java b/config-lib/src/main/java/com/yahoo/config/PathNode.java index b63dad4d1a7..9d73b5e23c2 100644 --- a/config-lib/src/main/java/com/yahoo/config/PathNode.java +++ b/config-lib/src/main/java/com/yahoo/config/PathNode.java @@ -14,7 +14,6 @@ import java.util.Map; * Represents a 'path' in a {@link ConfigInstance}, usually a filename. * * @author gjoranv - * @since 5.1.30 */ public class PathNode extends LeafNode<Path> { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 8c7398b3dde..afd33da369f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -13,12 +13,16 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; +import java.util.ArrayDeque; import java.util.Collection; import java.util.Collections; +import java.util.Deque; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Stack; +import java.util.stream.Collectors; /** * A context which only contains type information. @@ -26,21 +30,29 @@ import java.util.Optional; * query, attribute or constant features, as we do not have information about which such * features exist (but we know those that exist are doubles). * + * This is not multithread safe. + * * @author bratseth */ public class MapEvaluationTypeContext extends FunctionReferenceContext implements TypeContext<Reference> { private final Map<Reference, TensorType> featureTypes = new HashMap<>(); - public MapEvaluationTypeContext(Collection<ExpressionFunction> functions) { + /** For invocation loop detection */ + private final Deque<Reference> currentResolutionCallStack; + + MapEvaluationTypeContext(Collection<ExpressionFunction> functions) { super(functions); + this.currentResolutionCallStack = new ArrayDeque<>(); } - public MapEvaluationTypeContext(Map<String, ExpressionFunction> functions, - Map<String, String> bindings, - Map<Reference, TensorType> featureTypes) { + private MapEvaluationTypeContext(Map<String, ExpressionFunction> functions, + Map<String, String> bindings, + Map<Reference, TensorType> featureTypes, + Deque<Reference> currentResolutionCallStack) { super(functions, bindings); this.featureTypes.putAll(featureTypes); + this.currentResolutionCallStack = currentResolutionCallStack; } public void setType(Reference reference, TensorType type) { @@ -54,6 +66,11 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement @Override public TensorType getType(Reference reference) { + if (currentResolutionCallStack.contains(reference)) + throw new IllegalArgumentException("Invocation loop: " + + currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) + + " -> " + reference); + // A reference to a macro argument? Optional<String> binding = boundIdentifier(reference); if (binding.isPresent()) { @@ -61,36 +78,42 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement // This is not pretty, but changing to bind expressions rather // than their string values requires deeper changes return new RankingExpression(binding.get()).type(this); - } - catch (ParseException e) { + } catch (ParseException e) { throw new IllegalArgumentException(e); } } - // A reference to an attribute, query or constant feature? - if (FeatureNames.isSimpleFeature(reference)) { - // The argument may be a local identifier bound to the actual value - String argument = reference.simpleArgument().get(); - reference = Reference.simple(reference.name(), bindings.getOrDefault(argument, argument)); - return featureTypes.getOrDefault(reference, defaultTypeOf(reference)); - } + try { + currentResolutionCallStack.addLast(reference); - // A reference to a function? - Optional<ExpressionFunction> function = functionInvocation(reference); - if (function.isPresent()) { - return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments()))); - } + // A reference to an attribute, query or constant feature? + if (FeatureNames.isSimpleFeature(reference)) { + // The argument may be a local identifier bound to the actual value + String argument = reference.simpleArgument().get(); + reference = Reference.simple(reference.name(), bindings.getOrDefault(argument, argument)); + return featureTypes.getOrDefault(reference, defaultTypeOf(reference)); + } - // A reference to a feature which returns a tensor? - Optional<TensorType> featureTensorType = tensorFeatureType(reference); - if (featureTensorType.isPresent()) { - return featureTensorType.get(); - } + // A reference to a function? + Optional<ExpressionFunction> function = functionInvocation(reference); + if (function.isPresent()) { + return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments()))); + } + + // A reference to a feature which returns a tensor? + Optional<TensorType> featureTensorType = tensorFeatureType(reference); + if (featureTensorType.isPresent()) { + return featureTensorType.get(); + } - // We do not know what this is - since we do not have complete knowledge abut the match features - // in Java we must assume this is a match feature and return the double type - which is the type of all - // all match features - return TensorType.empty; + // We do not know what this is - since we do not have complete knowledge abut the match features + // in Java we must assume this is a match feature and return the double type - which is the type of all + // all match features + return TensorType.empty; + } + finally { + currentResolutionCallStack.removeLast(); + } } /** @@ -173,7 +196,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement @Override public MapEvaluationTypeContext withBindings(Map<String, String> bindings) { if (bindings.isEmpty() && this.bindings.isEmpty()) return this; - return new MapEvaluationTypeContext(functions(), bindings, featureTypes); + return new MapEvaluationTypeContext(functions(), bindings, featureTypes, currentResolutionCallStack); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 9d6a1351724..b7e1f9d4538 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -768,16 +768,13 @@ public class RankProfile implements Serializable, Cloneable { Map<String, Macro> inlineMacros, ExpressionTransforms expressionTransforms) { if (expression == null) return null; - Map<String, String> rankPropertiesOutput = new HashMap<>(); - RankProfileTransformContext context = new RankProfileTransformContext(this, queryProfiles, importedModels, constants, - inlineMacros, - rankPropertiesOutput); + inlineMacros); expression = expressionTransforms.transform(expression, context); - for (Map.Entry<String, String> rankProperty : rankPropertiesOutput.entrySet()) { + for (Map.Entry<String, String> rankProperty : context.rankProperties().entrySet()) { addRankProperty(rankProperty.getKey(), rankProperty.getValue()); } return expression; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java index a2bdc6834c9..7b7265e02ae 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java @@ -1,14 +1,21 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; +import com.yahoo.config.FileReference; import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.model.AbstractService; +import com.yahoo.vespa.model.utils.FileSender; +import java.util.Collection; import java.util.Objects; /** - * Represents a global ranking constant + * A global ranking constant distributed using file distribution. + * Ranking constants must be sent to some services to be useful - this is done + * by calling the sentTo method during the prepare phase of building models. * * @author arnej + * @author bratseth */ public class RankingConstant { @@ -49,14 +56,16 @@ public class RankingConstant { this.pathType = PathType.URI; } - /** - * Set the internally generated reference to this file used to identify this instance of the file for - * file distribution. - */ - public void setFileReference(String fileReference) { this.fileReference = fileReference; } - public void setType(TensorType tensorType) { this.tensorType = tensorType; } + /** Initiate sending of this constant to some services over file distribution */ + public void sendTo(Collection<? extends AbstractService> services) { + FileReference reference = (pathType == RankingConstant.PathType.FILE) + ? FileSender.sendFileToServices(path, services) + : FileSender.sendUriToServices(path, services); + this.fileReference = reference.value(); + } + public String getName() { return name; } public String getFileName() { return path; } public String getUri() { return path; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java index 164cb7f808e..e354c52092f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java @@ -1,6 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; +import com.yahoo.config.FileReference; +import com.yahoo.vespa.model.AbstractService; +import com.yahoo.vespa.model.utils.FileSender; + +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -33,4 +38,9 @@ public class RankingConstants { return Collections.unmodifiableMap(constants); } + /** Initiate sending of these constants to some services over file distribution */ + public void sendTo(Collection<? extends AbstractService> services) { + constants.values().forEach(constant -> constant.sendTo(services)); + } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 4af26b72817..9a00ee5bbd0 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -94,7 +94,7 @@ public class DerivedConfiguration { summaries = new Summaries(search, deployLogger); summaryMap = new SummaryMap(search, summaries); juniperrc = new Juniperrc(search); - rankProfileList = new RankProfileList(search, attributeFields, rankProfileRegistry, queryProfiles, importedModels); + rankProfileList = new RankProfileList(search, search.rankingConstants(), attributeFields, rankProfileRegistry, queryProfiles, importedModels); indexingScript = new IndexingScript(search); indexInfo = new IndexInfo(search); indexSchema = new IndexSchema(search); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 10881ab9ce0..fcbfb47c597 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -3,24 +3,35 @@ package com.yahoo.searchdefinition.derived; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.RankingConstants; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.Search; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import com.yahoo.vespa.model.AbstractService; + +import java.util.Collection; import java.util.Map; +import java.util.logging.Logger; /** * The derived rank profiles of a search definition * * @author bratseth */ -public class RankProfileList extends Derived implements RankProfilesConfig.Producer { +public class RankProfileList extends Derived implements RankProfilesConfig.Producer, RankingConstantsConfig.Producer { + + private static final Logger log = Logger.getLogger(RankProfileList.class.getName()); private final Map<String, RawRankProfile> rankProfiles = new java.util.LinkedHashMap<>(); + private final RankingConstants rankingConstants; public static RankProfileList empty = new RankProfileList(); private RankProfileList() { + this.rankingConstants = new RankingConstants(); } /** @@ -30,11 +41,13 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ * @param attributeFields the attribute fields to create a ranking for */ public RankProfileList(Search search, + RankingConstants rankingConstants, AttributeFields attributeFields, RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles, ImportedModels importedModels) { setName(search == null ? "default" : search.getName()); + this.rankingConstants = rankingConstants; deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields); } @@ -68,6 +81,10 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ return rankProfiles.get(name); } + public void sendConstantsTo(Collection<? extends AbstractService> services) { + rankingConstants.sendTo(services); + } + @Override public String getDerivedName() { return "rank-profiles"; } @@ -78,4 +95,17 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ } } + @Override + public void getConfig(RankingConstantsConfig.Builder builder) { + for (RankingConstant constant : rankingConstants.asMap().values()) { + if ("".equals(constant.getFileReference())) + log.warning("Illegal file reference " + constant); // Let tests pass ... we should find a better way + else + builder.constant(new RankingConstantsConfig.Constant.Builder() + .name(constant.getName()) + .fileref(constant.getFileReference()) + .type(constant.getType())); + } + } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 0d104a97698..43cc2fad285 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -316,18 +316,18 @@ public class RawRankProfile implements RankProfilesConfig.Producer { } catch (ParseException e) { throw new IllegalArgumentException("Could not parse second phase expression", e); } - continue; } - if ("rankingExpression(secondphase).rankingScript".equals(property.getName())) { + else if ("rankingExpression(secondphase).rankingScript".equals(property.getName())) { try { secondPhaseRanking = new RankingExpression(property.getValue()); } catch (ParseException e) { throw new IllegalArgumentException("Could not parse second phase expression", e); } - continue; } - properties.put(property.getName() + ".part" + i, property.getValue()); - i++; + else { + properties.put(property.getName() + ".part" + i, property.getValue()); + i++; + } } properties.putAll(deriveRankingPhaseRankProperties(firstPhaseRanking, "firstphase")); properties.putAll(deriveRankingPhaseRankProperties(secondPhaseRanking, "secondphase")); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java index ee38c518d6b..eb76446c045 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java @@ -2,7 +2,6 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.searchdefinition.FeatureNames; -import com.yahoo.searchdefinition.MapEvaluationTypeContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; @@ -60,8 +59,8 @@ public class ConstantTensorTransformer extends ExpressionTransformer<RankProfile TensorValue tensorValue = (TensorValue)value; String featureName = CONSTANT + "(" + node.getName() + ")"; String tensorType = tensorValue.asTensor().type().toString(); - context.rankPropertiesOutput().put(featureName + ".value", tensorValue.toString()); - context.rankPropertiesOutput().put(featureName + ".type", tensorType); + context.rankProperties().put(featureName + ".value", tensorValue.toString()); + context.rankProperties().put(featureName + ".type", tensorType); // TODO: This allows us to reference constant "a" as "a" instead of "constant(a)", but we shouldn't allow that return new ReferenceNode(CONSTANT, Arrays.asList(new NameNode(node.getName())), null); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 229ae0ebaaf..8634d51c418 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -1,5 +1,4 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.path.Path; @@ -8,12 +7,12 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; -import java.io.File; import java.io.UncheckedIOException; import java.util.HashMap; import java.util.Map; -import java.util.Optional; /** * Replaces instances of the onnx(model-path, output) @@ -43,9 +42,9 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans try { // TODO: Put modelPath in FeatureArguments instead - Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); + Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = - convertedOnnxModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context)); + convertedOnnxModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, true, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); } catch (IllegalArgumentException | UncheckedIOException e) { @@ -53,14 +52,14 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans } } - private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) { + private FeatureArguments asFeatureArguments(Arguments arguments) { if (arguments.isEmpty()) throw new IllegalArgumentException("An onnx node must take an argument pointing to " + "the onnx model directory under [application]/models"); if (arguments.expressions().size() > 3) throw new IllegalArgumentException("An onnx feature can have at most 2 arguments"); - return new ConvertedModel.FeatureArguments(arguments); + return new FeatureArguments(arguments); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java index c7b4e85d74e..40c3b997daa 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java @@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.transform.TransformContext; +import java.util.HashMap; import java.util.Map; /** @@ -20,26 +21,24 @@ public class RankProfileTransformContext extends TransformContext { private final QueryProfileRegistry queryProfiles; private final ImportedModels importedModels; private final Map<String, RankProfile.Macro> inlineMacros; - private final Map<String, String> rankPropertiesOutput; + private final Map<String, String> rankProperties = new HashMap<>(); public RankProfileTransformContext(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedModels importedModels, Map<String, Value> constants, - Map<String, RankProfile.Macro> inlineMacros, - Map<String, String> rankPropertiesOutput) { + Map<String, RankProfile.Macro> inlineMacros) { super(constants); this.rankProfile = rankProfile; this.queryProfiles = queryProfiles; this.importedModels = importedModels; this.inlineMacros = inlineMacros; - this.rankPropertiesOutput = rankPropertiesOutput; } public RankProfile rankProfile() { return rankProfile; } public QueryProfileRegistry queryProfiles() { return queryProfiles; } public ImportedModels importedModels() { return importedModels; } public Map<String, RankProfile.Macro> inlineMacros() { return inlineMacros; } - public Map<String, String> rankPropertiesOutput() { return rankPropertiesOutput; } + public Map<String, String> rankProperties() { return rankProperties; } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index bcb8ef1521d..5139d041f00 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -7,8 +7,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; -import java.io.File; import java.io.UncheckedIOException; import java.util.HashMap; import java.util.Map; @@ -39,9 +40,9 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil if ( ! feature.getName().equals("tensorflow")) return feature; try { - Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); + Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = - convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context)); + convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, false, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); } catch (IllegalArgumentException | UncheckedIOException e) { @@ -49,14 +50,14 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } - private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) { + private FeatureArguments asFeatureArguments(Arguments arguments) { if (arguments.isEmpty()) throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + "the tensorflow model directory under [application]/models"); if (arguments.expressions().size() > 3) throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); - return new ConvertedModel.FeatureArguments(arguments); + return new FeatureArguments(arguments); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java index b4a5069b9d6..f21248b6d74 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java @@ -2,13 +2,13 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.integration.ml.XGBoostImporter; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; import java.io.UncheckedIOException; import java.util.HashMap; @@ -41,20 +41,20 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr if ( ! feature.getName().equals("xgboost")) return feature; try { - Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); + Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = - convertedXGBoostModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context)); + convertedXGBoostModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, true, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use XGBoost model from " + feature, e); } } - private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) { + private FeatureArguments asFeatureArguments(Arguments arguments) { if (arguments.size() != 1) throw new IllegalArgumentException("An xgboost node must take a single argument pointing to " + "the xgboost model directory under [application]/models"); - return new ConvertedModel.FeatureArguments(arguments); + return new FeatureArguments(arguments); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 3e9d188670e..282e5a29962 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -2,7 +2,6 @@ package com.yahoo.vespa.model; import com.google.common.collect.ImmutableList; -import com.yahoo.collections.Pair; import com.yahoo.config.ConfigBuilder; import com.yahoo.config.ConfigInstance; import com.yahoo.config.ConfigInstance.Builder; @@ -33,7 +32,7 @@ import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.RankingConstants; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RankProfileList; -import com.yahoo.searchdefinition.expressiontransforms.ConvertedModel; +import com.yahoo.vespa.model.ml.ConvertedModel; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; @@ -54,6 +53,7 @@ import com.yahoo.vespa.model.content.cluster.ContentCluster; import com.yahoo.vespa.model.filedistribution.FileDistributionConfigProducer; import com.yahoo.vespa.model.filedistribution.FileDistributor; import com.yahoo.vespa.model.generic.service.ServiceCluster; +import com.yahoo.vespa.model.ml.ModelName; import com.yahoo.vespa.model.routing.Routing; import com.yahoo.vespa.model.search.AbstractSearchCluster; import com.yahoo.vespa.model.utils.internal.ReflectionUtil; @@ -169,6 +169,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri deployState.rankProfileRegistry(), deployState.getQueryProfiles().getRegistry()); this.rankProfileList = new RankProfileList(null, // null search -> global + rankingConstants, AttributeFields.empty, deployState.rankProfileRegistry(), deployState.getQueryProfiles().getRegistry(), @@ -232,7 +233,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri for (ImportedModel model : importedModels.all()) { RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry); rankProfileRegistry.add(profile); - ConvertedModel convertedModel = ConvertedModel.fromSource(model.name(), model.name(), profile, queryProfiles, model); + ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()), + model.name(), profile, queryProfiles, model); for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) { profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue()); } @@ -244,7 +246,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri String modelName = generatedModelDir.getPath().last(); RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry); rankProfileRegistry.add(profile); - ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, modelName, profile); + ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile); for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) { profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue()); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java index e8985b094ac..73d77406700 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java @@ -135,6 +135,8 @@ public class VespaMetricSet { metrics.add(new Metric("http.status.3xx.rate")); metrics.add(new Metric("http.status.4xx.rate")); metrics.add(new Metric("http.status.5xx.rate")); + metrics.add(new Metric("http.status.401.rate")); + metrics.add(new Metric("http.status.403.rate")); metrics.add(new Metric("jdisc.http.request.uri_length.average")); metrics.add(new Metric("jdisc.http.request.uri_length.max")); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV2Builder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV2Builder.java index d67cb0c29c3..9deb03495f2 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV2Builder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV2Builder.java @@ -90,7 +90,10 @@ public class DomAdminV2Builder extends DomAdminBuilderBase { if (standaloneZooKeeper) { parent = new ClusterControllerCluster(parent, "standalone"); } - ContainerCluster cluster = new ContainerCluster(parent, "cluster-controllers", "cluster-controllers", new ClusterControllerClusterVerifier(), RankProfileList.empty); + ContainerCluster cluster = new ContainerCluster(parent, + "cluster-controllers", + "cluster-controllers", + new ClusterControllerClusterVerifier()); ContainerModelBuilder.addDefaultHandler_legacyBuilder(cluster); List<Container> containers = new ArrayList<>(); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java index e34d490afe1..06450698a14 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java @@ -1,18 +1,21 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.builder.xml.dom; -import com.yahoo.component.Version; import com.yahoo.config.model.ConfigModelContext; import com.yahoo.config.model.api.ConfigServerSpec; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ClusterSpec; +import com.yahoo.config.provision.SystemName; +import com.yahoo.log.LogLevel; import com.yahoo.vespa.model.HostResource; import com.yahoo.vespa.model.HostSystem; import com.yahoo.vespa.model.admin.Admin; import com.yahoo.vespa.model.admin.Logserver; import com.yahoo.vespa.model.admin.Slobrok; import com.yahoo.vespa.model.container.Container; +import com.yahoo.vespa.model.container.ContainerCluster; import com.yahoo.vespa.model.container.ContainerModel; +import com.yahoo.vespa.model.container.component.Handler; import org.w3c.dom.Element; import java.util.ArrayList; @@ -55,7 +58,7 @@ public class DomAdminV4Builder extends DomAdminBuilderBase { NodesSpecification.optionalDedicatedFromParent(adminElement.getChild("logservers"), context); assignSlobroks(requestedSlobroks.orElse(NodesSpecification.nonDedicated(3, context)), admin); - assignLogserver(requestedLogservers.orElse(NodesSpecification.nonDedicated(1, context)), admin); + assignLogserver(requestedLogservers.orElse(createNodesSpecificationForLogserver()), admin); addLogForwarders(adminElement.getChild("logforwarding"), admin); } @@ -73,14 +76,57 @@ public class DomAdminV4Builder extends DomAdminBuilderBase { if (nodesSpecification.count() > 1) throw new IllegalArgumentException("You can only request a single log server"); if (nodesSpecification.isDedicated()) { - createLogserver(admin, allocateHosts(admin.getHostSystem(), "logserver", nodesSpecification)); - } - else { - if (containerModels.iterator().hasNext()) - createLogserver(admin, sortedContainerHostsFrom(containerModels.iterator().next(), nodesSpecification.count(), false)); + Collection<HostResource> hosts = allocateHosts(admin.getHostSystem(), "logserver", nodesSpecification); + if (hosts.isEmpty()) return; // No log server can be created (and none is needed) + + Logserver logserver = createLogserver(admin, hosts); + createAdditionalContainerOnLogserverHost(admin, logserver.getHostResource()); + } else if (containerModels.iterator().hasNext()) { + List<HostResource> hosts = sortedContainerHostsFrom(containerModels.iterator().next(), nodesSpecification.count(), false); + if (hosts.isEmpty()) return; // No log server can be created (and none is needed) + + createLogserver(admin, hosts); + } else { + context.getDeployLogger().log(LogLevel.INFO, "No container host available to use for running logserver"); } } + private NodesSpecification createNodesSpecificationForLogserver() { + // TODO: Enable for main system as well + //if (context.getDeployState().isHosted() && context.getDeployState().zone().system() == SystemName.cd) + // return NodesSpecification.dedicated(1, context); + //else + return NodesSpecification.nonDedicated(1, context); + } + + // Creates a container cluster 'logserver-cluster' with 1 container on logserver host + // for setting up a handler for getting logs from logserver + private void createAdditionalContainerOnLogserverHost(Admin admin, HostResource hostResource) { + ContainerCluster logServerCluster = new ContainerCluster(admin, "logserver-cluster", "logserver-cluster"); + ContainerModel logserverClusterModel = new ContainerModel(context.withParent(admin).withId(logServerCluster.getSubId())); + + // Add base handlers and the log handler + logServerCluster.addMetricStateHandler(); + logServerCluster.addApplicationStatusHandler(); + logServerCluster.addStatisticsHandler(); + logServerCluster.addDefaultRootHandler(); + addLogHandler(logServerCluster); + + logserverClusterModel.setCluster(logServerCluster); + + Container container = new Container(logServerCluster, "logserver-container", 0); + container.setHostResource(hostResource); + container.initService(); + logServerCluster.addContainer(container); + admin.addAndInitializeService(hostResource, container); + } + + private void addLogHandler(ContainerCluster cluster) { + Handler<?> logHandler = Handler.fromClassName("com.yahoo.container.handler.LogHandler"); + logHandler.addServerBindings("http://*/logs", "https://*/logs"); + cluster.addComponent(logHandler); + } + private Collection<HostResource> allocateHosts(HostSystem hostSystem, String clusterId, NodesSpecification nodesSpecification) { return nodesSpecification.provision(hostSystem, ClusterSpec.Type.admin, @@ -148,12 +194,12 @@ public class DomAdminV4Builder extends DomAdminBuilderBase { return HostResource.pickHosts(hosts, count, 1); } - private void createLogserver(Admin admin, Collection<HostResource> hosts) { - if (hosts.isEmpty()) return; // No log server can be created (and none is needed) + private Logserver createLogserver(Admin admin, Collection<HostResource> hosts) { Logserver logserver = new Logserver(admin); logserver.setHostResource(hosts.iterator().next()); admin.setLogserver(logserver); logserver.initService(); + return logserver; } private void createSlobroks(Admin admin, Collection<HostResource> hosts) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/NodesSpecification.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/NodesSpecification.java index 94359e8672e..0ab0c0b6d4f 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/NodesSpecification.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/NodesSpecification.java @@ -123,6 +123,19 @@ public class NodesSpecification { Optional.empty()); } + /** Returns a requirement from <code>count</code> dedicated nodes in one group */ + public static NodesSpecification dedicated(int count, ConfigModelContext context) { + return new NodesSpecification(true, + count, + 1, + context.getDeployState().getWantedNodeVespaVersion(), + false, + ! context.getDeployState().getProperties().isBootstrap(), + false, + Optional.empty(), + Optional.empty()); + } + /** * Returns whether this requires dedicated nodes. * Otherwise the model encountering this request should reuse nodes requested for other purposes whenever possible. diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java index 8c6c13d810f..f0724306e9c 100755 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.container; -import ai.vespa.models.evaluation.ModelsEvaluator; import com.yahoo.cloud.config.ClusterInfoConfig; import com.yahoo.cloud.config.ConfigserverConfig; import com.yahoo.cloud.config.RoutingProviderConfig; @@ -41,10 +40,8 @@ import com.yahoo.search.config.IndexInfoConfig; import com.yahoo.search.config.QrStartConfig; import com.yahoo.search.pagetemplates.PageTemplatesConfig; import com.yahoo.search.query.profile.config.QueryProfilesConfig; -import com.yahoo.searchdefinition.RankProfileRegistry; -import com.yahoo.searchdefinition.derived.AttributeFields; -import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.configdefinition.IlscriptsConfig; import com.yahoo.vespa.model.PortsMeta; import com.yahoo.vespa.model.Service; @@ -66,11 +63,9 @@ import com.yahoo.vespa.model.container.docproc.ContainerDocproc; import com.yahoo.vespa.model.container.docproc.DocprocChains; import com.yahoo.vespa.model.container.http.Http; import com.yahoo.vespa.model.container.jersey.Jersey2Servlet; -import com.yahoo.vespa.model.container.jersey.JerseyHandler; import com.yahoo.vespa.model.container.jersey.RestApi; import com.yahoo.vespa.model.container.processing.ProcessingChains; import com.yahoo.vespa.model.container.search.ContainerSearch; -import com.yahoo.vespa.model.container.search.QueryProfiles; import com.yahoo.vespa.model.container.search.searchchain.SearchChains; import com.yahoo.vespa.model.content.Content; import com.yahoo.vespa.model.search.AbstractSearchCluster; @@ -79,7 +74,6 @@ import com.yahoo.vespaclient.config.FeederConfig; import edu.umd.cs.findbugs.annotations.NonNull; import edu.umd.cs.findbugs.annotations.Nullable; - import java.nio.file.Path; import java.util.ArrayList; import java.util.Collection; @@ -91,7 +85,6 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; -import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -129,7 +122,8 @@ public final class ContainerCluster RoutingProviderConfig.Producer, ConfigserverConfig.Producer, ThreadpoolConfig.Producer, - RankProfilesConfig.Producer + RankProfilesConfig.Producer, + RankingConstantsConfig.Producer { @@ -161,6 +155,7 @@ public final class ContainerCluster private ContainerDocproc containerDocproc; private ContainerDocumentApi containerDocumentApi; private SecretStore secretStore; + private ContainerModelEvaluation modelEvaluation; private MbusParams mbusParams; private boolean rpcServerEnabled = true; @@ -177,9 +172,6 @@ public final class ContainerCluster private final ContainerClusterVerifier clusterVerifier; private final boolean isHostedVespa; - /** Global rank profiles, aka models */ - private final RankProfileList rankProfileList; - private Map<String, String> concreteDocumentTypes = new LinkedHashMap<>(); private MetricDefaultsConfig.Factory.Enum defaultMetricConsumerFactory; @@ -204,30 +196,16 @@ public final class ContainerCluster } } - /** - * Creates a container cluster - * - * @param rankProfileList the list ofd global rank profiles containing models that should be available in - * container clusters - */ - public ContainerCluster(AbstractConfigProducer<?> parent, - String subId, - String name, - RankProfileList rankProfileList) { - this(parent, subId, name, new AcceptAllVerifier(), rankProfileList); + /** Creates a container cluster */ + public ContainerCluster(AbstractConfigProducer<?> parent, String subId, String name) { + this(parent, subId, name, new AcceptAllVerifier()); } - /** - * Creates a container cluster - * - * @param rankProfileList the list ofd global rank profiles containing models that should be available in - * container clusters - */ + /** Creates a container cluster */ public ContainerCluster(AbstractConfigProducer<?> parent, String subId, String name, - ContainerClusterVerifier verifier, - RankProfileList rankProfileList) { + ContainerClusterVerifier verifier) { super(parent, subId); this.clusterVerifier = verifier; this.name = name; @@ -237,14 +215,12 @@ public final class ContainerCluster componentGroup = new ComponentGroup<>(this, "component"); restApiGroup = new ConfigProducerGroup<>(this, "rest-api"); servletGroup = new ConfigProducerGroup<>(this, "servlet"); - this.rankProfileList = Objects.requireNonNull(rankProfileList, "rankProfileList cannot be null"); addComponent(new StatisticsComponent()); addSimpleComponent(AccessLog.class); // TODO better modelling addSimpleComponent(ThreadPoolProvider.class); addSimpleComponent(com.yahoo.concurrent.classlock.ClassLocking.class); - addSimpleComponent(ModelsEvaluator.class.getName(), null, "model-evaluation"); addSimpleComponent("com.yahoo.jdisc.http.filter.SecurityFilterInvoker"); addSimpleComponent(SIMPLE_LINGUISTICS_PROVIDER); addSimpleComponent("com.yahoo.container.jdisc.SecretStoreProvider"); @@ -364,6 +340,8 @@ public final class ContainerCluster public void prepare() { addAndSendApplicationBundles(); + if (modelEvaluation != null) + modelEvaluation.prepare(containers); sendUserConfiguredFiles(); setApplicationMetaData(); for (RestApi restApi : restApiGroup.getComponents()) @@ -460,6 +438,10 @@ public final class ContainerCluster this.containerSearch = containerSearch; } + public void setModelEvaluation(ContainerModelEvaluation modelEvaluation) { + this.modelEvaluation = modelEvaluation; + } + public void setHttp(Http http) { this.http = http; addChild(http); @@ -554,18 +536,14 @@ public final class ContainerCluster } @Override - public final void getConfig(ComponentsConfig.Builder builder) { + public void getConfig(ComponentsConfig.Builder builder) { builder.components.addAll(ComponentsConfigGenerator.generate(getAllComponents())); builder.components(new ComponentsConfig.Components.Builder().id("com.yahoo.container.core.config.HandlersConfigurerDi$RegistriesHack")); } @Override - public final void getConfig(JdiscBindingsConfig.Builder builder) { + public void getConfig(JdiscBindingsConfig.Builder builder) { builder.handlers.putAll(DiscBindingsConfigGenerator.generate(getHandlers())); - - allJersey1Handlers().forEach(handler -> - builder.handlers.putAll(DiscBindingsConfigGenerator.generate(handler)) - ); } @Override @@ -573,10 +551,6 @@ public final class ContainerCluster clusterVerifier.getConfig(builder); } - private Stream<JerseyHandler> allJersey1Handlers() { - return restApiGroup.getComponents().stream().flatMap(streamOf(RestApi::getJersey1Handler)); - } - @Override public void getConfig(ServletPathsConfig.Builder builder) { allServlets().forEach(servlet -> @@ -591,14 +565,7 @@ public final class ContainerCluster } private Stream<Jersey2Servlet> allJersey2Servlets() { - return restApiGroup.getComponents().stream().flatMap(streamOf(RestApi::getJersey2Servlet)); - } - - private <T, R> Function<T, Stream<R>> streamOf(Function<T, Optional<R>> f) { - return t -> - f.apply(t). - <Stream<R>>map(Stream::of). - orElse(Stream.empty()); + return restApiGroup.getComponents().stream().map(RestApi::getJersey2Servlet); } @Override @@ -670,47 +637,37 @@ public final class ContainerCluster @Override public void getConfig(DocprocConfig.Builder builder) { - if (containerDocproc != null) { - containerDocproc.getConfig(builder); - } + if (containerDocproc != null) containerDocproc.getConfig(builder); } @Override public void getConfig(PageTemplatesConfig.Builder builder) { - if (containerSearch != null) { - containerSearch.getConfig(builder); - } + if (containerSearch != null) containerSearch.getConfig(builder); } @Override public void getConfig(SemanticRulesConfig.Builder builder) { - if (containerSearch != null) { - containerSearch.getConfig(builder); - } + if (containerSearch != null) containerSearch.getConfig(builder); } @Override public void getConfig(QueryProfilesConfig.Builder builder) { - if (containerSearch != null) { - containerSearch.getConfig(builder); - } + if (containerSearch != null) containerSearch.getConfig(builder); } @Override public void getConfig(SchemamappingConfig.Builder builder) { - if (containerDocproc!=null) containerDocproc.getConfig(builder); + if (containerDocproc != null) containerDocproc.getConfig(builder); } @Override public void getConfig(IndexInfoConfig.Builder builder) { - if (containerSearch!=null) containerSearch.getConfig(builder); + if (containerSearch != null) containerSearch.getConfig(builder); } @Override public void getConfig(FeederConfig.Builder builder) { - if (containerDocumentApi != null) { - containerDocumentApi.getConfig(builder); - } + if (containerDocumentApi != null) containerDocumentApi.getConfig(builder); } @Override @@ -729,7 +686,12 @@ public final class ContainerCluster @Override public void getConfig(RankProfilesConfig.Builder builder) { - rankProfileList.getConfig(builder); + if (modelEvaluation != null) modelEvaluation.getConfig(builder); + } + + @Override + public void getConfig(RankingConstantsConfig.Builder builder) { + if (modelEvaluation != null) modelEvaluation.getConfig(builder); } public void setMbusParams(MbusParams mbusParams) { @@ -737,8 +699,7 @@ public final class ContainerCluster } public void initialize(Map<String, AbstractSearchCluster> clusterMap) { - if (containerSearch != null) - containerSearch.connectSearchClusters(clusterMap); + if (containerSearch != null) containerSearch.connectSearchClusters(clusterMap); } public void addDefaultSearchAccessLog() { @@ -756,9 +717,7 @@ public final class ContainerCluster @Override public void getConfig(MetricDefaultsConfig.Builder builder) { - if (defaultMetricConsumerFactory != null) { - builder.factory(defaultMetricConsumerFactory); - } + if (defaultMetricConsumerFactory != null) builder.factory(defaultMetricConsumerFactory); } @Override @@ -868,4 +827,5 @@ public final class ContainerCluster this.containerCoreMemory = containerCoreMemory; } } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java new file mode 100644 index 00000000000..09990c7b9de --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java @@ -0,0 +1,41 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.container; + +import ai.vespa.models.evaluation.ModelsEvaluator; +import com.yahoo.searchdefinition.derived.RankProfileList; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; + +import java.util.List; +import java.util.Objects; + +/** + * Configuration of components for stateless model evaluation + * + * @author bratseth + */ +public class ContainerModelEvaluation implements RankProfilesConfig.Producer, RankingConstantsConfig.Producer { + + /** Global rank profiles, aka models */ + private final RankProfileList rankProfileList; + + public ContainerModelEvaluation(ContainerCluster cluster, RankProfileList rankProfileList) { + this.rankProfileList = Objects.requireNonNull(rankProfileList, "rankProfileList cannot be null"); + cluster.addSimpleComponent(ModelsEvaluator.class.getName(), null, "model-evaluation"); + } + + public void prepare(List<Container> containers) { + rankProfileList.sendConstantsTo(containers); + } + + @Override + public void getConfig(RankProfilesConfig.Builder builder) { + rankProfileList.getConfig(builder); + } + + @Override + public void getConfig(RankingConstantsConfig.Builder builder) { + rankProfileList.getConfig(builder); + } + +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/http/ConnectorFactory.java b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ConnectorFactory.java index 33f5edded3c..a9d3ec0e5a2 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/http/ConnectorFactory.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ConnectorFactory.java @@ -4,11 +4,10 @@ package com.yahoo.vespa.model.container.http; import com.yahoo.component.ComponentId; import com.yahoo.container.bundle.BundleInstantiationSpecification; import com.yahoo.jdisc.http.ConnectorConfig; -import com.yahoo.jdisc.http.ssl.DefaultSslKeyStoreConfigurator; -import com.yahoo.jdisc.http.ssl.DefaultSslTrustStoreConfigurator; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.text.XML; import com.yahoo.vespa.model.container.component.SimpleComponent; +import com.yahoo.vespa.model.container.http.ssl.LegacySslProvider; import org.w3c.dom.Element; import static com.yahoo.component.ComponentSpecification.fromString; @@ -17,22 +16,23 @@ import static com.yahoo.jdisc.http.ConnectorConfig.Ssl.KeyStoreType; /** * @author Einar M R Rosenvinge * @author bjorncs + * @author mortent */ public class ConnectorFactory extends SimpleComponent implements ConnectorConfig.Producer { private final String name; private final int listenPort; private final Element legacyConfig; + private final SimpleComponent sslProviderComponent; public ConnectorFactory(String name, int listenPort) { - this(name, listenPort, null, null, null); + this(name, listenPort, null, new LegacySslProvider(name)); } public ConnectorFactory(String name, int listenPort, Element legacyConfig, - Element sslKeystoreConfigurator, - Element sslTruststoreConfigurator) { + SimpleComponent sslProviderComponent) { super(new ComponentModel( new BundleInstantiationSpecification(new ComponentId(name), fromString("com.yahoo.jdisc.http.server.jetty.ConnectorFactory"), @@ -40,8 +40,9 @@ public class ConnectorFactory extends SimpleComponent implements ConnectorConfig this.name = name; this.listenPort = listenPort; this.legacyConfig = legacyConfig; - addSslKeyStoreConfigurator(name, sslKeystoreConfigurator); - addSslTrustStoreConfigurator(name, sslTruststoreConfigurator); + this.sslProviderComponent = sslProviderComponent; + addChild(sslProviderComponent); + inject(sslProviderComponent); } @Override @@ -49,6 +50,7 @@ public class ConnectorFactory extends SimpleComponent implements ConnectorConfig configureWithLegacyHttpConfig(legacyConfig, connectorBuilder); connectorBuilder.listenPort(listenPort); connectorBuilder.name(name); + ((ConnectorConfig.Producer)sslProviderComponent).getConfig(connectorBuilder); } public String getName() { @@ -152,31 +154,4 @@ public class ConnectorFactory extends SimpleComponent implements ConnectorConfig } } } - - private void addSslKeyStoreConfigurator(String name, Element sslKeystoreConfigurator) { - addSslConfigurator("ssl-keystore-configurator@" + name, - DefaultSslKeyStoreConfigurator.class, - sslKeystoreConfigurator); - } - - private void addSslTrustStoreConfigurator(String name, Element sslKeystoreConfigurator) { - addSslConfigurator("ssl-truststore-configurator@" + name, - DefaultSslTrustStoreConfigurator.class, - sslKeystoreConfigurator); - } - - private void addSslConfigurator(String idSpec, Class<?> defaultImplementation, Element configuratorElement) { - SimpleComponent configuratorComponent; - if (configuratorElement != null) { - String className = configuratorElement.getAttribute("class"); - String bundleName = configuratorElement.getAttribute("bundle"); - configuratorComponent = new SimpleComponent(new ComponentModel(idSpec, className, bundleName)); - } else { - configuratorComponent = - new SimpleComponent(new ComponentModel(idSpec, defaultImplementation.getName(), "jdisc_http_service")); - } - addChild(configuratorComponent); - inject(configuratorComponent); - } - } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/CustomSslProvider.java b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/CustomSslProvider.java new file mode 100644 index 00000000000..bc211925576 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/CustomSslProvider.java @@ -0,0 +1,29 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.container.http.ssl; + +import com.yahoo.component.ComponentId; +import com.yahoo.container.bundle.BundleInstantiationSpecification; +import com.yahoo.jdisc.http.ConnectorConfig; +import com.yahoo.osgi.provider.model.ComponentModel; +import com.yahoo.vespa.model.container.component.SimpleComponent; + +import static com.yahoo.component.ComponentSpecification.fromString; + +/** + * @author mortent + */ +public class CustomSslProvider extends SimpleComponent implements ConnectorConfig.Producer { + public static final String COMPONENT_ID_PREFIX = "ssl-provider@"; + + public CustomSslProvider(String serverName, String className, String bundle) { + super(new ComponentModel( + new BundleInstantiationSpecification(new ComponentId(COMPONENT_ID_PREFIX + serverName), + fromString(className), + fromString(bundle)))); + } + + @Override + public void getConfig(ConnectorConfig.Builder builder) { + builder.ssl.enabled(true); + } +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/DefaultSslProvider.java b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/DefaultSslProvider.java new file mode 100644 index 00000000000..fc4b6b8cd0f --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/DefaultSslProvider.java @@ -0,0 +1,63 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.container.http.ssl; + +import com.yahoo.component.ComponentId; +import com.yahoo.container.bundle.BundleInstantiationSpecification; +import com.yahoo.jdisc.http.ConnectorConfig; +import com.yahoo.jdisc.http.ssl.DefaultSslContextFactoryProvider; +import com.yahoo.osgi.provider.model.ComponentModel; +import com.yahoo.vespa.model.container.component.SimpleComponent; + +import java.util.Optional; + +import static com.yahoo.component.ComponentSpecification.fromString; + +/** + * @author mortent + */ +public class DefaultSslProvider extends SimpleComponent implements ConnectorConfig.Producer { + public static final String COMPONENT_ID_PREFIX = "default-ssl-provider@"; + public static final String COMPONENT_CLASS = DefaultSslContextFactoryProvider.class.getName(); + public static final String COMPONENT_BUNDLE = "jdisc_http_service"; + + private final String privateKeyPath; + private final String certificatePath; + private final String caCertificatePath; + private final ConnectorConfig.Ssl.ClientAuth.Enum clientAuthentication; + + public DefaultSslProvider(String servername, String privateKeyPath, String certificatePath, String caCertificatePath, String clientAuthentication) { + super(new ComponentModel( + new BundleInstantiationSpecification(new ComponentId(COMPONENT_ID_PREFIX+servername), + fromString(COMPONENT_CLASS), + fromString(COMPONENT_BUNDLE)))); + this.privateKeyPath = privateKeyPath; + this.certificatePath = certificatePath; + this.caCertificatePath = caCertificatePath; + this.clientAuthentication = mapToConfigEnum(clientAuthentication); + } + + @Override + public void getConfig(ConnectorConfig.Builder builder) { + builder.ssl.enabled(true); + builder.ssl.privateKeyFile(privateKeyPath); + builder.ssl.certificateFile(certificatePath); + builder.ssl.caCertificateFile(Optional.ofNullable(caCertificatePath).orElse("")); + builder.ssl.clientAuth(clientAuthentication); + } + + public SimpleComponent getComponent() { + return new SimpleComponent(new ComponentModel(getComponentId().stringValue(), COMPONENT_CLASS, COMPONENT_BUNDLE)); + } + + private static ConnectorConfig.Ssl.ClientAuth.Enum mapToConfigEnum(String clientAuthValue) { + if ("disabled".equals(clientAuthValue)) { + return ConnectorConfig.Ssl.ClientAuth.Enum.DISABLED; + } else if ("want".equals(clientAuthValue)) { + return ConnectorConfig.Ssl.ClientAuth.Enum.WANT_AUTH; + } else if ("need".equals(clientAuthValue)) { + return ConnectorConfig.Ssl.ClientAuth.Enum.NEED_AUTH; + } else { + return ConnectorConfig.Ssl.ClientAuth.Enum.DISABLED; + } + } +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/LegacySslProvider.java b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/LegacySslProvider.java new file mode 100644 index 00000000000..fedc8c4a843 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/LegacySslProvider.java @@ -0,0 +1,36 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.container.http.ssl; + +import com.yahoo.component.ComponentId; +import com.yahoo.container.bundle.BundleInstantiationSpecification; +import com.yahoo.jdisc.http.ConnectorConfig; +import com.yahoo.jdisc.http.ssl.SslContextFactoryProvider; +import com.yahoo.jdisc.http.ssl.LegacySslContextFactoryProvider; +import com.yahoo.osgi.provider.model.ComponentModel; +import com.yahoo.vespa.model.container.component.SimpleComponent; + +import static com.yahoo.component.ComponentSpecification.fromString; + +/** + * Provides a legacy implementation of {@link SslContextFactoryProvider} to be injected into non-ssl connectors and connectors using legacy ssl config override + * + * @author bjorncs + */ +public class LegacySslProvider extends SimpleComponent implements ConnectorConfig.Producer { + + public static final String COMPONENT_ID_PREFIX = "legacy-ssl-provider@"; + public static final String COMPONENT_CLASS = LegacySslContextFactoryProvider.class.getName(); + public static final String COMPONENT_BUNDLE = "jdisc_http_service"; + + public LegacySslProvider(String serverName) { + super(new ComponentModel( + new BundleInstantiationSpecification(new ComponentId(COMPONENT_ID_PREFIX + serverName), + fromString(COMPONENT_CLASS), + fromString(COMPONENT_BUNDLE)))); + } + + @Override + public void getConfig(ConnectorConfig.Builder builder) { + + } +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/http/xml/JettyConnectorBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/http/xml/JettyConnectorBuilder.java index f88c091cd37..36736d66195 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/http/xml/JettyConnectorBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/http/xml/JettyConnectorBuilder.java @@ -5,14 +5,20 @@ import com.yahoo.config.model.builder.xml.XmlHelper; import com.yahoo.config.model.producer.AbstractConfigProducer; import com.yahoo.text.XML; import com.yahoo.vespa.model.builder.xml.dom.VespaDomBuilder; +import com.yahoo.vespa.model.container.component.SimpleComponent; import com.yahoo.vespa.model.container.http.ConnectorFactory; +import com.yahoo.vespa.model.container.http.ssl.CustomSslProvider; +import com.yahoo.vespa.model.container.http.ssl.DefaultSslProvider; +import com.yahoo.vespa.model.container.http.ssl.LegacySslProvider; import org.w3c.dom.Element; +import java.util.Optional; import java.util.logging.Level; import java.util.logging.Logger; /** * @author Einar M R Rosenvinge + * @author mortent */ public class JettyConnectorBuilder extends VespaDomBuilder.DomConfigProducerBuilder<ConnectorFactory> { private static final Logger log = Logger.getLogger(JettyConnectorBuilder.class.getName()); @@ -32,9 +38,31 @@ public class JettyConnectorBuilder extends VespaDomBuilder.DomConfigProducerBuil legacyServerConfig = null; } } - Element sslKeystoreConfigurator = XML.getChild(serverSpec, "ssl-keystore-configurator"); - Element sslTruststoreConfigurator = XML.getChild(serverSpec, "ssl-truststore-configurator"); - return new ConnectorFactory(name, port, legacyServerConfig, sslKeystoreConfigurator, sslTruststoreConfigurator); + SimpleComponent sslProviderComponent = getSslConfigComponents(name, serverSpec); + return new ConnectorFactory(name, port, legacyServerConfig, sslProviderComponent); } + SimpleComponent getSslConfigComponents(String serverName, Element serverSpec) { + Element sslConfigurator = XML.getChild(serverSpec, "ssl"); + Element sslProviderConfigurator = XML.getChild(serverSpec, "ssl-provider"); + + if (sslConfigurator != null) { + String privateKeyFile = XML.getValue(XML.getChild(sslConfigurator, "private-key-file")); + String certificateFile = XML.getValue(XML.getChild(sslConfigurator, "certificate-file")); + Optional<String> caCertificateFile = XmlHelper.getOptionalChildValue(sslConfigurator, "ca-certificates-file"); + Optional<String> clientAuthentication = XmlHelper.getOptionalChildValue(sslConfigurator, "client-authentication"); + return new DefaultSslProvider( + serverName, + privateKeyFile, + certificateFile, + caCertificateFile.orElse(null), + clientAuthentication.orElse(null)); + } else if (sslProviderConfigurator != null) { + String className = sslProviderConfigurator.getAttribute("class"); + String bundle = sslProviderConfigurator.getAttribute("bundle"); + return new CustomSslProvider(serverName, className, bundle); + } else { + return new LegacySslProvider(serverName); + } + } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/JerseyHandler.java b/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/JerseyHandler.java deleted file mode 100644 index 737882b703d..00000000000 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/JerseyHandler.java +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.model.container.jersey; - -import com.yahoo.config.model.producer.AbstractConfigProducer; -import com.yahoo.container.bundle.BundleInstantiationSpecification; -import com.yahoo.osgi.provider.model.ComponentModel; -import com.yahoo.vespa.model.container.component.Handler; - -/** - * @author gjoranv - * @since 5.6 - */ -public class JerseyHandler extends Handler<AbstractConfigProducer<?>> { - - public static final String BUNDLE = "container-jersey"; - public static final String CLASS = "com.yahoo.container.jdisc.jersey.JerseyHandler"; - - public JerseyHandler(String bindingPath) { - super(new ComponentModel(bundleSpec(CLASS, BUNDLE, bindingPath))); - } - - public static BundleInstantiationSpecification bundleSpec(String className, String bundle, String bindingPath) { - return BundleInstantiationSpecification.getFromStrings( - className + "-" + RestApi.idFromPath(bindingPath), - className, - bundle); - } -} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApi.java b/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApi.java index 63825aa2a1b..be8209bcc4e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApi.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApi.java @@ -2,32 +2,24 @@ package com.yahoo.vespa.model.container.jersey; import com.yahoo.config.model.producer.AbstractConfigProducer; -import com.yahoo.container.config.jersey.JerseyInitConfig; -import com.yahoo.vespa.model.container.component.Component; - -import java.util.Optional; /** + * Represents a rest-api + * * @author gjoranv - * @since 5.6 */ -public class RestApi extends AbstractConfigProducer<AbstractConfigProducer<?>> implements - JerseyInitConfig.Producer -{ - public final boolean isJersey2; +public class RestApi extends AbstractConfigProducer<AbstractConfigProducer<?>> { + private final String bindingPath; - private final Component<?, ?> jerseyHandler; + private final Jersey2Servlet jerseyServlet; private RestApiContext restApiContext; - public RestApi(String bindingPath, boolean isJersey2) { + public RestApi(String bindingPath) { super(idFromPath(bindingPath)); this.bindingPath = bindingPath; - this.isJersey2 = isJersey2; - jerseyHandler = isJersey2 ? - createJersey2Servlet(this.bindingPath): - createJersey1Handler(this.bindingPath); - addChild(jerseyHandler); + jerseyServlet = createJersey2Servlet(this.bindingPath); + addChild(jerseyServlet); } public static String idFromPath(String path) { @@ -38,44 +30,20 @@ public class RestApi extends AbstractConfigProducer<AbstractConfigProducer<?>> i return new Jersey2Servlet(bindingPath); } - private static JerseyHandler createJersey1Handler(String bindingPath) { - JerseyHandler jerseyHandler = new JerseyHandler(bindingPath); - jerseyHandler.addServerBindings(getBindings(bindingPath)); - return jerseyHandler; - } - public String getBindingPath() { return bindingPath; } - @Override - public void getConfig(JerseyInitConfig.Builder builder) { - builder.jerseyMapping(bindingPath); - } - public void setRestApiContext(RestApiContext restApiContext) { this.restApiContext = restApiContext; addChild(restApiContext); - jerseyHandler.inject(restApiContext); + jerseyServlet.inject(restApiContext); } public RestApiContext getContext() { return restApiContext; } - public Optional<JerseyHandler> getJersey1Handler() { - return isJersey2 ? - Optional.empty(): - Optional.of((JerseyHandler)jerseyHandler); - } - - public Optional<Jersey2Servlet> getJersey2Servlet() { - return isJersey2 ? - Optional.of((Jersey2Servlet)jerseyHandler) : - Optional.empty(); - } - - private static String[] getBindings(String bindingPath) { - String bindingWithoutScheme = "://*/" + bindingPath + "/*"; - return new String[] {"http" + bindingWithoutScheme, "https" + bindingWithoutScheme}; + public Jersey2Servlet getJersey2Servlet() { + return jerseyServlet; } public void prepare() { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApiContext.java b/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApiContext.java index 5e48a1b1951..7fce9d2b636 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApiContext.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApiContext.java @@ -22,7 +22,6 @@ import java.util.logging.Logger; /** * @author gjoranv - * @since 5.16 */ public class RestApiContext extends SimpleComponent implements JerseyBundlesConfig.Producer, @@ -87,10 +86,6 @@ public class RestApiContext extends SimpleComponent implements } } - public void addInjections(Map<String, String> injections) { - injectComponentForClass.putAll(injections); - } - @Override public void validate() throws Exception { super.validate(); @@ -117,7 +112,6 @@ public class RestApiContext extends SimpleComponent implements private Predicate<Component> isCycleGeneratingComponent = component -> { switch (component.getClassId().getName()) { case CONTAINER_CLASS: - case JerseyHandler.CLASS: case Jersey2Servlet.CLASS: case "com.yahoo.jdisc.http.server.jetty.JettyHttpServer": case "com.yahoo.container.handler.observability.ApplicationStatusHandler": diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/xml/RestApiBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/xml/RestApiBuilder.java index 245db3c014f..6728f0be29f 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/xml/RestApiBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/xml/RestApiBuilder.java @@ -13,8 +13,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalAttribute; - /** * @author gjoranv * @since 5.6 @@ -24,8 +22,7 @@ public class RestApiBuilder extends VespaDomBuilder.DomConfigProducerBuilder<Res @Override protected RestApi doBuild(AbstractConfigProducer ancestor, Element spec) { String bindingPath = spec.getAttribute("path"); - boolean jersey2 = Boolean.parseBoolean(getOptionalAttribute(spec, "jersey2").orElse("false")); - RestApi restApi = new RestApi(bindingPath, jersey2); + RestApi restApi = new RestApi(bindingPath); restApi.setRestApiContext( createRestApiContext(ancestor, spec, bindingPath)); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/processing/ProcessingChain.java b/config-model/src/main/java/com/yahoo/vespa/model/container/processing/ProcessingChain.java index 73430d8e453..9a984ca1917 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/processing/ProcessingChain.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/processing/ProcessingChain.java @@ -8,7 +8,6 @@ import com.yahoo.vespa.model.container.component.chain.Chain; * Represents a processing chain in the config model * * @author bratseth - * @since 5.1.6 */ public class ProcessingChain extends Chain<Processor> { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/processing/ProcessingChains.java b/config-model/src/main/java/com/yahoo/vespa/model/container/processing/ProcessingChains.java index 644316ff652..32f0f373a92 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/processing/ProcessingChains.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/processing/ProcessingChains.java @@ -12,14 +12,14 @@ import java.util.List; * Root config producer for processing * * @author bratseth - * @since 5.1.6 */ public class ProcessingChains extends Chains<ProcessingChain> { - public static final String[] defaultBindings = new String[] - {"http://*/processing/*", "https://*/processing/*"}; + + public static final String[] defaultBindings = new String[] {"http://*/processing/*", "https://*/processing/*"}; public ProcessingChains(AbstractConfigProducer parent, String subId) { super(parent, subId); } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/processing/Processor.java b/config-model/src/main/java/com/yahoo/vespa/model/container/processing/Processor.java index e4f46be914b..3ad6484aaec 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/processing/Processor.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/processing/Processor.java @@ -9,7 +9,6 @@ import com.yahoo.vespa.model.container.component.chain.ChainedComponent; * Representation of a Processor in the configuration model * * @author bratseth - * @since 5.1.6 */ public class Processor extends ChainedComponent<ChainedComponentModel> { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java index c711f268534..36feba34680 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java @@ -35,11 +35,10 @@ public class ContainerSearch extends ContainerSubsystem<SearchChains> QrStartConfig.Producer, QueryProfilesConfig.Producer, SemanticRulesConfig.Producer, - PageTemplatesConfig.Producer -{ + PageTemplatesConfig.Producer { private final List<AbstractSearchCluster> systems = new LinkedList<>(); - private Options options = null; + private final Options options; // For legacy qrs clusters only. private BinaryScaledAmount totalCacheSize = new BinaryScaledAmount(); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java index 19d014e0a1d..ceb48732116 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java @@ -15,6 +15,7 @@ import java.util.*; /** * Config producer for the FederationSearcher. + * * @author Tony Vaagenes */ public class FederationSearcher extends Searcher<FederationSearcherModel> implements FederationConfig.Producer { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index cb4cf92a223..2d3f3036ccc 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -45,6 +45,7 @@ import com.yahoo.vespa.model.clients.ContainerDocumentApi; import com.yahoo.vespa.model.container.Container; import com.yahoo.vespa.model.container.ContainerCluster; import com.yahoo.vespa.model.container.ContainerModel; +import com.yahoo.vespa.model.container.ContainerModelEvaluation; import com.yahoo.vespa.model.container.IdentityProvider; import com.yahoo.vespa.model.container.SecretStore; import com.yahoo.vespa.model.container.component.Component; @@ -148,11 +149,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { return new VespaDomBuilder.DomConfigProducerBuilder<ContainerCluster>() { @Override protected ContainerCluster doBuild(AbstractConfigProducer ancestor, Element producerSpec) { - return new ContainerCluster(ancestor, - modelContext.getProducerId(), - modelContext.getProducerId(), - modelContext.vespaModel() != null ? modelContext.vespaModel().rankProfileList() - : RankProfileList.empty); + return new ContainerCluster(ancestor, modelContext.getProducerId(), modelContext.getProducerId()); } }.build(modelContext.getParentProducer(), spec); } @@ -166,6 +163,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { addServlets(spec, cluster); addProcessing(spec, cluster); addSearch(spec, cluster, context.getDeployState().getQueryProfiles(), context.getDeployState().getSemanticRules()); + addModelEvaluation(spec, cluster, context); addDocproc(spec, cluster); addDocumentApi(spec, cluster); // NOTE: Must be done after addSearch @@ -355,50 +353,56 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { } private void addServlets(Element spec, ContainerCluster cluster) { - for (Element servletElem : XML.getChildren(spec, "servlet")) { - cluster.addServlet( - new ServletBuilder().build(cluster, servletElem)); - } + for (Element servletElem : XML.getChildren(spec, "servlet")) + cluster.addServlet(new ServletBuilder().build(cluster, servletElem)); } private void addDocumentApi(Element spec, ContainerCluster cluster) { ContainerDocumentApi containerDocumentApi = buildDocumentApi(cluster, spec); - if (containerDocumentApi != null) { - cluster.setDocumentApi(containerDocumentApi); - } + if (containerDocumentApi == null) return; + + cluster.setDocumentApi(containerDocumentApi); } private void addDocproc(Element spec, ContainerCluster cluster) { ContainerDocproc containerDocproc = buildDocproc(cluster, spec); - if (containerDocproc != null) { - cluster.setDocproc(containerDocproc); + if (containerDocproc == null) return; + cluster.setDocproc(containerDocproc); - ContainerDocproc.Options docprocOptions = containerDocproc.options; - cluster.setMbusParams(new ContainerCluster.MbusParams( - docprocOptions.maxConcurrentFactor, docprocOptions.documentExpansionFactor, docprocOptions.containerCoreMemory)); - } + ContainerDocproc.Options docprocOptions = containerDocproc.options; + cluster.setMbusParams(new ContainerCluster.MbusParams( + docprocOptions.maxConcurrentFactor, docprocOptions.documentExpansionFactor, docprocOptions.containerCoreMemory)); } private void addSearch(Element spec, ContainerCluster cluster, QueryProfiles queryProfiles, SemanticRules semanticRules) { Element searchElement = XML.getChild(spec, "search"); - if (searchElement != null) { - addIncludes(searchElement); - cluster.setSearch(buildSearch(cluster, searchElement, queryProfiles, semanticRules)); + if (searchElement == null) return; - addSearchHandler(cluster, searchElement); - addGUIHandler(cluster); - validateAndAddConfiguredComponents(cluster, searchElement, "renderer", ContainerModelBuilder::validateRendererElement); - } + addIncludes(searchElement); + cluster.setSearch(buildSearch(cluster, searchElement, queryProfiles, semanticRules)); + + addSearchHandler(cluster, searchElement); + addGUIHandler(cluster); + validateAndAddConfiguredComponents(cluster, searchElement, "renderer", ContainerModelBuilder::validateRendererElement); + } + + private void addModelEvaluation(Element spec, ContainerCluster cluster, ConfigModelContext context) { + Element modelEvaluationElement = XML.getChild(spec, "model-evaluation"); + if (modelEvaluationElement == null) return; + + RankProfileList profiles = + context.vespaModel() != null ? context.vespaModel().rankProfileList() : RankProfileList.empty; + cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles)); } private void addProcessing(Element spec, ContainerCluster cluster) { Element processingElement = XML.getChild(spec, "processing"); - if (processingElement != null) { - addIncludes(processingElement); - cluster.setProcessingChains(new DomProcessingBuilder(null).build(cluster, processingElement), - serverBindings(processingElement, ProcessingChains.defaultBindings)); - validateAndAddConfiguredComponents(cluster, processingElement, "renderer", ContainerModelBuilder::validateRendererElement); - } + if (processingElement == null) return; + + addIncludes(processingElement); + cluster.setProcessingChains(new DomProcessingBuilder(null).build(cluster, processingElement), + serverBindings(processingElement, ProcessingChains.defaultBindings)); + validateAndAddConfiguredComponents(cluster, processingElement, "renderer", ContainerModelBuilder::validateRendererElement); } private ContainerSearch buildSearch(ContainerCluster containerCluster, Element producerSpec, diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/Content.java b/config-model/src/main/java/com/yahoo/vespa/model/content/Content.java index d3709e88f29..ec68243ec9d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/Content.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/Content.java @@ -299,7 +299,7 @@ public class Content extends ConfigModel { AbstractConfigProducer parent = root.getChildren().get(ContainerModel.DOCPROC_RESERVED_NAME); if (parent == null) parent = new SimpleConfigProducer(root, ContainerModel.DOCPROC_RESERVED_NAME); - ContainerCluster indexingCluster = new ContainerCluster(parent, "cluster." + indexerName, indexerName, RankProfileList.empty); + ContainerCluster indexingCluster = new ContainerCluster(parent, "cluster." + indexerName, indexerName); ContainerModel indexingClusterModel = new ContainerModel(modelContext.withParent(parent).withId(indexingCluster.getSubId())); indexingClusterModel.setCluster(indexingCluster); modelContext.getConfigModelRepoAdder().add(indexingClusterModel); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java index cce367ed611..f15ba547894 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java @@ -443,8 +443,7 @@ public class ContentCluster extends AbstractConfigProducer implements ContainerCluster clusterControllers = new ContainerCluster(parent, name, name, - new ClusterControllerClusterVerifier(), - RankProfileList.empty); + new ClusterControllerClusterVerifier()); List<Container> containers = new ArrayList<>(); // Add a cluster controller on each config server (there is always at least one). if (clusterControllers.getContainers().isEmpty()) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 6acb9ff1f7e..e2236feb336 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -1,4 +1,5 @@ -package com.yahoo.searchdefinition.expressiontransforms; +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; import com.google.common.collect.ImmutableMap; import com.yahoo.collections.Pair; @@ -11,6 +12,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.FeatureNames; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; @@ -18,7 +20,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -63,14 +64,14 @@ import java.util.stream.Collectors; */ public class ConvertedModel { - private final String modelName; + private final ModelName modelName; private final String modelDescription; private final ImmutableMap<String, RankingExpression> expressions; /** The source importedModel, or empty if this was created from a stored converted model */ private final Optional<ImportedModel> sourceModel; - private ConvertedModel(String modelName, + private ConvertedModel(ModelName modelName, String modelDescription, Map<String, RankingExpression> expressions, Optional<ImportedModel> sourceModel) { @@ -83,10 +84,13 @@ public class ConvertedModel { /** * Create and store a converted model for a rank profile given from either an imported model, * or (if unavailable) from stored application package data. + * + * @param modelPath the path to the model + * @param pathIsFile true if that path (this kind of model) is stored in a file, false if it is in a directory */ - public static ConvertedModel fromSourceOrStore(Path modelPath, RankProfileTransformContext context) { + public static ConvertedModel fromSourceOrStore(Path modelPath, boolean pathIsFile, RankProfileTransformContext context) { File sourceModel = sourceModelFile(context.rankProfile().applicationPackage(), modelPath); - String modelName = context.rankProfile().getName() + "." + toModelName(modelPath); // must be unique to each profile + ModelName modelName = new ModelName(context.rankProfile().getName(), modelPath, pathIsFile); if (sourceModel.exists()) return fromSource(modelName, modelPath.toString(), @@ -99,7 +103,7 @@ public class ConvertedModel { context.rankProfile()); } - public static ConvertedModel fromSource(String modelName, + public static ConvertedModel fromSource(ModelName modelName, String modelDescription, RankProfile rankProfile, QueryProfileRegistry queryProfileRegistry, @@ -111,7 +115,7 @@ public class ConvertedModel { Optional.of(importedModel)); } - public static ConvertedModel fromStore(String modelName, + public static ConvertedModel fromStore(ModelName modelName, String modelDescription, RankProfile rankProfile) { ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName); @@ -240,9 +244,12 @@ public class ConvertedModel { profile.addConstant(constantName, asValue(constantValue)); } - private static void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, - Set<String> constantsReplacedByMacros, - String constantName, Tensor constantValue) { + private static void transformLargeConstant(ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles, + Set<String> constantsReplacedByMacros, + String constantName, + Tensor constantValue) { RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); if (macroOverridingConstant != null) { TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); @@ -255,7 +262,7 @@ public class ConvertedModel { Path constantPath = store.writeLargeConstant(constantName, constantValue); if ( ! profile.rankingConstants().asMap().containsKey(constantName)) { profile.rankingConstants().add(new RankingConstant(constantName, constantValue.type(), - constantPath.toString())); + constantPath.toString())); } } } @@ -491,10 +498,6 @@ public class ConvertedModel { return new TensorValue(tensor); } - private static String toModelName(Path modelPath) { - return modelPath.toString().replace("/", "_"); - } - @Override public String toString() { return "model '" + modelName + "'"; } @@ -513,7 +516,7 @@ public class ConvertedModel { private final ApplicationPackage application; private final ModelFiles modelFiles; - ModelStore(ApplicationPackage application, String modelName) { + ModelStore(ApplicationPackage application, ModelName modelName) { this.application = application; this.modelFiles = new ModelFiles(modelName); } @@ -616,15 +619,19 @@ public class ConvertedModel { .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); // Write content explicitly as a file on the file system as this is distributed using file distribution - createIfNeeded(constantsPath); - IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); + // - but only if this is a global model to avoid writing the same constants for each rank profile + // where they are used + if (modelFiles.modelName.isGlobal()) { + createIfNeeded(constantsPath); + IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); + } return correct(constantPath); } private List<Pair<String, Tensor>> readSmallConstants() { try { ApplicationFile file = application.getFile(modelFiles.smallConstantsPath()); - if (!file.exists()) return Collections.emptyList(); + if ( ! file.exists()) return Collections.emptyList(); List<Pair<String, Tensor>> constants = new ArrayList<>(); BufferedReader reader = new BufferedReader(file.createReader()); @@ -676,20 +683,24 @@ public class ConvertedModel { static class ModelFiles { - String modelName; + ModelName modelName; - public ModelFiles(String modelName) { + public ModelFiles(ModelName modelName) { this.modelName = modelName; } /** Files stored below this path will be replicated in zookeeper */ public Path storedModelReplicatedPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelName); + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelName.fullName()); } - /** Files stored below this path will not be replicated in zookeeper */ - public Path storedModelPath() { - return ApplicationPackage.MODELS_GENERATED_DIR.append(modelName); + /** + * Files stored below this path will not be replicated in zookeeper. + * Large constants are only stored under the global (not rank-profile-specific) + * path to avoid storing the same large constant multiple times. + */ + public Path storedGlobalModelPath() { + return ApplicationPackage.MODELS_GENERATED_DIR.append(modelName.localName()); } public Path expressionPath(String name) { @@ -701,12 +712,12 @@ public class ConvertedModel { } public Path smallConstantsPath() { - return storedModelPath().append("constants.txt"); + return storedModelReplicatedPath().append("constants.txt"); } /** Path to the large (ranking) constants directory */ public Path largeConstantsContentPath() { - return storedModelPath().append("constants"); + return storedGlobalModelPath().append("constants"); } /** Path to the large (ranking) constants directory */ @@ -721,53 +732,4 @@ public class ConvertedModel { } - /** Encapsulates the arguments of a specific model output */ - static class FeatureArguments { - - /** Optional arguments */ - private final Optional<String> signature, output; - - public FeatureArguments(Arguments arguments) { - this(optionalArgument(1, arguments), - optionalArgument(2, arguments)); - } - - public FeatureArguments(Optional<String> signature, Optional<String> output) { - this.signature = signature; - this.output = output; - } - - public Optional<String> signature() { return signature; } - public Optional<String> output() { return output; } - - public String toName() { - return (signature.isPresent() ? signature.get() : "") + - (output.isPresent() ? "." + output.get() : ""); - } - - private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { - if (argumentIndex >= arguments.expressions().size()) - return Optional.empty(); - return Optional.of(asString(arguments.expressions().get(argumentIndex))); - } - - public static String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); - } - - private static String stripQuotes(String s) { - if ( ! isQuoteSign(s.codePointAt(0))) return s; - if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); - return s.substring(1, s.length()-1); - } - - private static boolean isQuoteSign(int c) { - return c == '\'' || c == '"'; - } - - } - } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java new file mode 100644 index 00000000000..fda49af6178 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java @@ -0,0 +1,61 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; + +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; + +import java.util.Optional; + +/** + * Encapsulates the arguments of a specific model output + * + * @author bratseth + */ +public class FeatureArguments { + + /** Optional arguments */ + private final Optional<String> signature, output; + + public FeatureArguments(Arguments arguments) { + this(optionalArgument(1, arguments), + optionalArgument(2, arguments)); + } + + public FeatureArguments(Optional<String> signature, Optional<String> output) { + this.signature = signature; + this.output = output; + } + + public Optional<String> signature() { return signature; } + public Optional<String> output() { return output; } + + public String toName() { + return (signature.isPresent() ? signature.get() : "") + + (output.isPresent() ? "." + output.get() : ""); + } + + private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { + if (argumentIndex >= arguments.expressions().size()) + return Optional.empty(); + return Optional.of(asString(arguments.expressions().get(argumentIndex))); + } + + public static String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); + } + + private static String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); + } + + private static boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; + } + +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java new file mode 100644 index 00000000000..2c7dc6b337d --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java @@ -0,0 +1,62 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; + +import com.yahoo.path.Path; + +/** + * Models used in a rank profile has the rank profile name as name space while gGlobal model names have no namespace + * + * @author bratseth + */ +public class ModelName { + + /** The namespace, or null if none */ + private String namespace; + private String name; + private String fullName; + + public ModelName(String name) { + this(null, name); + } + + public ModelName(String namespace, Path modelPath, boolean pathIsFile) { + this(namespace, + stripFileEndingIfFile(modelPath, pathIsFile).toString().replace("/", "_")); + } + + private ModelName(String namespace, String name) { + this.namespace = namespace; + this.name = name; + this.fullName = (namespace != null ? namespace + "." : "") + name; + } + + private static Path stripFileEndingIfFile(Path path, boolean pathIsFile) { + if ( ! pathIsFile) return path; + int dotIndex = path.last().lastIndexOf("."); + if (dotIndex <= 0) return path; + return path.withLast(path.last().substring(0, dotIndex)); + } + + /** Returns true if the local name of this is not in a namespace */ + public boolean isGlobal() { return namespace == null; } + + /** Returns the namespace, or null if this is global */ + public String namespace() { return namespace; } + public String localName() { return name; } + public String fullName() { return fullName; } + + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if ( ! (o instanceof ModelName)) return false; + return ((ModelName)o).fullName.equals(this.fullName); + } + + @Override + public int hashCode() { return fullName.hashCode(); } + + @Override + public String toString() { return fullName; } + +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/routing/Protocol.java b/config-model/src/main/java/com/yahoo/vespa/model/routing/Protocol.java index ad684894176..49596aa0ddf 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/routing/Protocol.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/routing/Protocol.java @@ -12,18 +12,10 @@ import com.yahoo.messagebus.routing.RoutingTableSpec; */ public interface Protocol { - /** - * Returns the specification for the routing table of this protocol. - * - * @return The routing table spec. - */ - public RoutingTableSpec getRoutingTableSpec(); + /** Returns the specification for the routing table of this protocol. */ + RoutingTableSpec getRoutingTableSpec(); - /** - * Returns the specification of the application as seen by this protocol. - * - * @return The application spec. - */ - public ApplicationSpec getApplicationSpec(); + /** Returns the specification of the application as seen by this protocol. */ + ApplicationSpec getApplicationSpec(); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/routing/Routing.java b/config-model/src/main/java/com/yahoo/vespa/model/routing/Routing.java index 16f51935f2a..2403594d331 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/routing/Routing.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/routing/Routing.java @@ -21,7 +21,7 @@ public class Routing extends ConfigModel { private final List<String> errors = new ArrayList<>(); private ApplicationSpec explicitApplication = null; private RoutingSpec explicitRouting = null; - private List<Protocol> protocols = new ArrayList<>(); + private final List<Protocol> protocols = new ArrayList<>(); private RoutingSpec derivedRouting; public Routing(ConfigModelContext modelContext) { @@ -91,7 +91,7 @@ public class Routing extends ConfigModel { } public void getConfig(MessagebusConfig.Builder builder) { - if (derivedRouting==null) { + if (derivedRouting == null) { // The error list should be populated then return; } @@ -198,4 +198,5 @@ public class Routing extends ConfigModel { public List<String> getErrors() { return Collections.unmodifiableList(errors); } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java index 83da5d96418..fbbf029d5f1 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java @@ -1,16 +1,13 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.search; -import com.yahoo.config.FileReference; import com.yahoo.config.model.producer.AbstractConfigProducer; import com.yahoo.config.model.producer.UserConfigRepo; import com.yahoo.prelude.fastsearch.DocumentdbInfoConfig; import com.yahoo.search.config.IndexInfoConfig; -import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.configdefinition.IlscriptsConfig; -import com.yahoo.vespa.model.utils.FileSender; import java.util.ArrayList; import java.util.LinkedList; @@ -36,14 +33,8 @@ public abstract class AbstractSearchCluster extends AbstractConfigProducer protected List<SearchDefinitionSpec> localSDS = new LinkedList<>(); public void prepareToDistributeFiles(List<SearchNode> backends) { - for (SearchDefinitionSpec sds : localSDS) { - for (RankingConstant constant : sds.getSearchDefinition().getSearch().rankingConstants().asMap().values()) { - FileReference reference = (constant.getPathType() == RankingConstant.PathType.FILE) - ? FileSender.sendFileToServices(constant.getFileName(), backends) - : FileSender.sendUriToServices(constant.getUri(), backends); - constant.setFileReference(reference.value()); - } - } + for (SearchDefinitionSpec sds : localSDS) + sds.getSearchDefinition().getSearch().rankingConstants().sendTo(backends); } public static final class IndexingMode { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java index a6bf51a2503..b29ed0fc25b 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java @@ -75,16 +75,7 @@ public class DocumentDatabase extends AbstractConfigProducer implements @Override public void getConfig(RankingConstantsConfig.Builder builder) { - for (RankingConstant constant : derivedCfg.getSearch().rankingConstants().asMap().values()) { - if ("".equals(constant.getFileReference())) { - System.err.println("INVALID rank constant "+constant.getName()+" [missing file reference]"); // TODO: Throw or log warning - continue; - } - builder.constant(new RankingConstantsConfig.Constant.Builder() - .name(constant.getName()) - .fileref(constant.getFileReference()) - .type(constant.getType())); - } + derivedCfg.getRankProfileList().getConfig(builder); } @Override diff --git a/config-model/src/main/python/ES_Vespa_parser.py b/config-model/src/main/python/ES_Vespa_parser.py index 477b0db4744..4721dbe8128 100644 --- a/config-model/src/main/python/ES_Vespa_parser.py +++ b/config-model/src/main/python/ES_Vespa_parser.py @@ -61,7 +61,7 @@ class ElasticSearchParser: _all_enabled = data[index]["mappings"][type]["_all"]["enabled"] if not _all_enabled: self._all = False - print(" > All fields in the document type '" + type + "' is not searchable. Go inside " + self.path + type + ".sd to add which fields that should be searchable") + print(" > Not all fields in the document type '" + type + "' are searchable. Edit " + self.path + "searchdefinitions/" + type + ".sd to control which fields are searchable") except KeyError: print(" > All fields in the document type '" + type + "' is searchable") @@ -108,7 +108,7 @@ class ElasticSearchParser: vespa_docs.close() unparsed_document_file.close() - print(" > Parsed all documents '" + ", ".join(self.types) + "'" + "' at '" + file_path + "'") + print(" > Parsed all documents '" + ", ".join(self.types) + "' at '" + file_path + "'") def createSearchDefinition(self, type, type_mapping): file_path = self.path + "searchdefinitions/" + type + ".sd" @@ -117,6 +117,10 @@ class ElasticSearchParser: new_sd.write(" document " + type + " {\n") for key, item in type_mapping.items(): + type = self.get_type(item) + if(type == "nested"): + print(" > SKIPPING FIELD " + key + ", this tool is not yet able to convert nested fields") + continue new_sd.write(" field " + key + " type " + self.get_type(item) + " {\n") new_sd.write(" indexing: " + self.get_indexing(key, self.get_type(item)) + "\n") new_sd.write(" }\n") @@ -180,6 +184,8 @@ class ElasticSearchParser: def get_type(self, type): return { + "integer": "int", + "string": "string", # for compatability with older ES versions "text": "string", "keyword": "string", "date": "string", @@ -189,6 +195,7 @@ class ElasticSearchParser: "ip": "text", "byte": "byte", "float": "float", + "nested": "nested" }[type] diff --git a/config-model/src/main/resources/schema/containercluster.rnc b/config-model/src/main/resources/schema/containercluster.rnc index a8c763e43b7..4934ce113bb 100644 --- a/config-model/src/main/resources/schema/containercluster.rnc +++ b/config-model/src/main/resources/schema/containercluster.rnc @@ -16,6 +16,7 @@ ContainerServices = SearchInContainer? & DocprocInContainer? & ProcessingInContainer? & + ModelEvaluation? & DocumentApi? & Components* & Component* & @@ -62,8 +63,7 @@ Filtering = element filtering { HttpServer = element server { attribute port { xsd:nonNegativeInteger } & ComponentId & - element ssl-keystore-configurator { BundleSpec }? & # FOR INTERNAL USE ONLY - SUBJECT TO CHANGE - element ssl-truststore-configurator { BundleSpec }? & # FOR INTERNAL USE ONLY - SUBJECT TO CHANGE + (Ssl | SslProvider)? & GenericConfig* } @@ -85,6 +85,21 @@ SecretStore = element secret-store { } + } +ModelEvaluation = element model-evaluation { + empty +} + +Ssl = element ssl { + element private-key-file { string } & + element certificate-file { string } & + element ca-certificates-file { string }? & + element client-authentication { string "disabled" | string "want" | string "need" }? +} + +SslProvider = element ssl-provider { + BundleSpec +} + # REST-API: RestApi = element rest-api { diff --git a/config-model/src/test/cfg/application/ml_serving/services.xml b/config-model/src/test/cfg/application/ml_serving/services.xml index 42528336bc5..41f44e04c99 100644 --- a/config-model/src/test/cfg/application/ml_serving/services.xml +++ b/config-model/src/test/cfg/application/ml_serving/services.xml @@ -3,6 +3,7 @@ <services version="1.0"> <container version="1.0"> + <model-evaluation/> <nodes> <node hostalias="node1" /> </nodes> diff --git a/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax.onnx b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax.onnx Binary files differnew file mode 100644 index 00000000000..a86019bf53a --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax.onnx diff --git a/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/saved_model.pbtxt b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/saved_model.pbtxt new file mode 100644 index 00000000000..05b0e4e0f29 --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/saved_model.pbtxt @@ -0,0 +1,5039 @@ +saved_model_schema_version: 1 +meta_graphs { + meta_info_def { + stripped_op_list { + op { + name: "Add" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_STRING + } + } + } + } + op { + name: "ApplyGradientDescent" + input_arg { + name: "var" + type_attr: "T" + is_ref: true + } + input_arg { + name: "alpha" + type_attr: "T" + } + input_arg { + name: "delta" + type_attr: "T" + } + output_arg { + name: "out" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: false + } + } + } + op { + name: "ArgMax" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "dimension" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "output_type" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "output_type" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Assign" + input_arg { + name: "ref" + type_attr: "T" + is_ref: true + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output_ref" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + } + attr { + name: "validate_shape" + type: "bool" + default_value { + b: true + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: true + } + } + allows_uninitialized_input: true + } + op { + name: "BroadcastGradientArgs" + input_arg { + name: "s0" + type_attr: "T" + } + input_arg { + name: "s1" + type_attr: "T" + } + output_arg { + name: "r0" + type_attr: "T" + } + output_arg { + name: "r1" + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Cast" + input_arg { + name: "x" + type_attr: "SrcT" + } + output_arg { + name: "y" + type_attr: "DstT" + } + attr { + name: "SrcT" + type: "type" + } + attr { + name: "DstT" + type: "type" + } + } + op { + name: "ConcatV2" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + input_arg { + name: "axis" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 2 + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Const" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "value" + type: "tensor" + } + attr { + name: "dtype" + type: "type" + } + } + op { + name: "Equal" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type: DT_BOOL + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_QUINT8 + type: DT_QINT8 + type: DT_QINT32 + type: DT_STRING + type: DT_BOOL + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "ExpandDims" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "dim" + type_attr: "Tdim" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tdim" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Fill" + input_arg { + name: "dims" + type: DT_INT32 + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "FloorDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Identity" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "MatMul" + input_arg { + name: "a" + type_attr: "T" + } + input_arg { + name: "b" + type_attr: "T" + } + output_arg { + name: "product" + type_attr: "T" + } + attr { + name: "transpose_a" + type: "bool" + default_value { + b: false + } + } + attr { + name: "transpose_b" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Maximum" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + is_commutative: true + } + op { + name: "Mean" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "MergeV2Checkpoints" + input_arg { + name: "checkpoint_prefixes" + type: DT_STRING + } + input_arg { + name: "destination_prefix" + type: DT_STRING + } + attr { + name: "delete_old_dirs" + type: "bool" + default_value { + b: true + } + } + is_stateful: true + } + op { + name: "Mul" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "NoOp" + } + op { + name: "Pack" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + } + attr { + name: "axis" + type: "int" + default_value { + i: 0 + } + } + } + op { + name: "Placeholder" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + } + op { + name: "Prod" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RealDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Reshape" + input_arg { + name: "tensor" + type_attr: "T" + } + input_arg { + name: "shape" + type_attr: "Tshape" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tshape" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RestoreV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + output_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "SaveV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + input_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "Shape" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "T" + type: "type" + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "ShardedFilename" + input_arg { + name: "basename" + type: DT_STRING + } + input_arg { + name: "shard" + type: DT_INT32 + } + input_arg { + name: "num_shards" + type: DT_INT32 + } + output_arg { + name: "filename" + type: DT_STRING + } + } + op { + name: "Slice" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "begin" + type_attr: "Index" + } + input_arg { + name: "size" + type_attr: "Index" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Index" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "SoftmaxCrossEntropyWithLogits" + input_arg { + name: "features" + type_attr: "T" + } + input_arg { + name: "labels" + type_attr: "T" + } + output_arg { + name: "loss" + type_attr: "T" + } + output_arg { + name: "backprop" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + } + op { + name: "StringJoin" + input_arg { + name: "inputs" + type: DT_STRING + number_attr: "N" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "separator" + type: "string" + default_value { + s: "" + } + } + } + op { + name: "Sub" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Sum" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Tile" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "multiples" + type_attr: "Tmultiples" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tmultiples" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "VariableV2" + output_arg { + name: "ref" + type_attr: "dtype" + is_ref: true + } + attr { + name: "shape" + type: "shape" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true + } + op { + name: "ZerosLike" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + } + tags: "serve" + tensorflow_version: "1.4.1" + tensorflow_git_version: "v1.4.0-19-ga52c8d9" + } + graph_def { + node { + name: "Placeholder" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + node { + name: "Placeholder_1" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + node { + name: "layer/zeros" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "layer/Variable" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "layer/Variable/Assign" + op: "Assign" + input: "layer/Variable" + input: "layer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "layer/Variable/read" + op: "Identity" + input: "layer/Variable" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "layer/zeros_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 10 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "layer/Variable_1" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "layer/Variable_1/Assign" + op: "Assign" + input: "layer/Variable_1" + input: "layer/zeros_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "layer/Variable_1/read" + op: "Identity" + input: "layer/Variable_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "layer/MatMul" + op: "MatMul" + input: "Placeholder" + input: "layer/Variable/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "layer/add" + op: "Add" + input: "layer/MatMul" + input: "layer/Variable_1/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "Rank" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape" + op: "Shape" + input: "layer/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Rank_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape_1" + op: "Shape" + input: "layer/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Sub/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub" + op: "Sub" + input: "Rank_1" + input: "Sub/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice/begin" + op: "Pack" + input: "Sub" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice/size" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "Slice" + op: "Slice" + input: "Shape_1" + input: "Slice/begin" + input: "Slice/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "concat/values_0" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node { + name: "concat/axis" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "concat" + op: "ConcatV2" + input: "concat/values_0" + input: "Slice" + input: "concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + } + node { + name: "Reshape" + op: "Reshape" + input: "layer/add" + input: "concat" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Rank_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape_2" + op: "Shape" + input: "Placeholder_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Sub_1/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub_1" + op: "Sub" + input: "Rank_2" + input: "Sub_1/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice_1/begin" + op: "Pack" + input: "Sub_1" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice_1/size" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "Slice_1" + op: "Slice" + input: "Shape_2" + input: "Slice_1/begin" + input: "Slice_1/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "concat_1/values_0" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node { + name: "concat_1/axis" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "concat_1" + op: "ConcatV2" + input: "concat_1/values_0" + input: "Slice_1" + input: "concat_1/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + } + node { + name: "Reshape_1" + op: "Reshape" + input: "Placeholder_1" + input: "concat_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "SoftmaxCrossEntropyWithLogits" + op: "SoftmaxCrossEntropyWithLogits" + input: "Reshape" + input: "Reshape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Sub_2/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub_2" + op: "Sub" + input: "Rank" + input: "Sub_2/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice_2/begin" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Slice_2/size" + op: "Pack" + input: "Sub_2" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice_2" + op: "Slice" + input: "Shape" + input: "Slice_2/begin" + input: "Slice_2/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Reshape_2" + op: "Reshape" + input: "SoftmaxCrossEntropyWithLogits" + input: "Slice_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Mean" + op: "Mean" + input: "Reshape_2" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "gradients/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } + } + node { + name: "gradients/Fill" + op: "Fill" + input: "gradients/Shape" + input: "gradients/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/Reshape/shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "gradients/Mean_grad/Reshape" + op: "Reshape" + input: "gradients/Fill" + input: "gradients/Mean_grad/Reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Shape" + op: "Shape" + input: "Reshape_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Mean_grad/Tile" + op: "Tile" + input: "gradients/Mean_grad/Reshape" + input: "gradients/Mean_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tmultiples" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Shape_1" + op: "Shape" + input: "Reshape_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Mean_grad/Shape_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "gradients/Mean_grad/Prod" + op: "Prod" + input: "gradients/Mean_grad/Shape_1" + input: "gradients/Mean_grad/Const" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Mean_grad/Const_1" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "gradients/Mean_grad/Prod_1" + op: "Prod" + input: "gradients/Mean_grad/Shape_2" + input: "gradients/Mean_grad/Const_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Mean_grad/Maximum/y" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "gradients/Mean_grad/Maximum" + op: "Maximum" + input: "gradients/Mean_grad/Prod_1" + input: "gradients/Mean_grad/Maximum/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/floordiv" + op: "FloorDiv" + input: "gradients/Mean_grad/Prod" + input: "gradients/Mean_grad/Maximum" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/Cast" + op: "Cast" + input: "gradients/Mean_grad/floordiv" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/truediv" + op: "RealDiv" + input: "gradients/Mean_grad/Tile" + input: "gradients/Mean_grad/Cast" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Reshape_2_grad/Shape" + op: "Shape" + input: "SoftmaxCrossEntropyWithLogits" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Reshape_2_grad/Reshape" + op: "Reshape" + input: "gradients/Mean_grad/truediv" + input: "gradients/Reshape_2_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/zeros_like" + op: "ZerosLike" + input: "SoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: -1 + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + op: "ExpandDims" + input: "gradients/Reshape_2_grad/Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + op: "Mul" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + input: "SoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Reshape_grad/Shape" + op: "Shape" + input: "layer/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Reshape_grad/Reshape" + op: "Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + input: "gradients/Reshape_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/layer/add_grad/Shape" + op: "Shape" + input: "layer/MatMul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/layer/add_grad/Shape_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 10 + } + } + } + } + node { + name: "gradients/layer/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/layer/add_grad/Shape" + input: "gradients/layer/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/layer/add_grad/Sum" + op: "Sum" + input: "gradients/Reshape_grad/Reshape" + input: "gradients/layer/add_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/layer/add_grad/Reshape" + op: "Reshape" + input: "gradients/layer/add_grad/Sum" + input: "gradients/layer/add_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/layer/add_grad/Sum_1" + op: "Sum" + input: "gradients/Reshape_grad/Reshape" + input: "gradients/layer/add_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/layer/add_grad/Reshape_1" + op: "Reshape" + input: "gradients/layer/add_grad/Sum_1" + input: "gradients/layer/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/layer/add_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/layer/add_grad/Reshape" + input: "^gradients/layer/add_grad/Reshape_1" + } + node { + name: "gradients/layer/add_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/layer/add_grad/Reshape" + input: "^gradients/layer/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/layer/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/layer/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/layer/add_grad/Reshape_1" + input: "^gradients/layer/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/layer/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/layer/MatMul_grad/MatMul" + op: "MatMul" + input: "gradients/layer/add_grad/tuple/control_dependency" + input: "layer/Variable/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } + } + node { + name: "gradients/layer/MatMul_grad/MatMul_1" + op: "MatMul" + input: "Placeholder" + input: "gradients/layer/add_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "gradients/layer/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/layer/MatMul_grad/MatMul" + input: "^gradients/layer/MatMul_grad/MatMul_1" + } + node { + name: "gradients/layer/MatMul_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/layer/MatMul_grad/MatMul" + input: "^gradients/layer/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/layer/MatMul_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "gradients/layer/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/layer/MatMul_grad/MatMul_1" + input: "^gradients/layer/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/layer/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "GradientDescent/learning_rate" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + } + node { + name: "GradientDescent/update_layer/Variable/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "layer/Variable" + input: "GradientDescent/learning_rate" + input: "gradients/layer/MatMul_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "GradientDescent/update_layer/Variable_1/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "layer/Variable_1" + input: "GradientDescent/learning_rate" + input: "gradients/layer/add_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "GradientDescent" + op: "NoOp" + input: "^GradientDescent/update_layer/Variable/ApplyGradientDescent" + input: "^GradientDescent/update_layer/Variable_1/ApplyGradientDescent" + } + node { + name: "init" + op: "NoOp" + input: "^layer/Variable/Assign" + input: "^layer/Variable_1/Assign" + } + node { + name: "ArgMax/dimension" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "ArgMax" + op: "ArgMax" + input: "layer/add" + input: "ArgMax/dimension" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_type" + value { + type: DT_INT64 + } + } + } + node { + name: "ArgMax_1/dimension" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "ArgMax_1" + op: "ArgMax" + input: "Placeholder_1" + input: "ArgMax_1/dimension" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_type" + value { + type: DT_INT64 + } + } + } + node { + name: "Equal" + op: "Equal" + input: "ArgMax" + input: "ArgMax_1" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Cast_1" + op: "Cast" + input: "Equal" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_BOOL + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Const_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Mean_1" + op: "Mean" + input: "Cast_1" + input: "Const_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "save/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "model" + } + } + } + } + node { + name: "save/StringJoin/inputs_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "_temp_65caff16d5244276b9828b0dab21b157/part" + } + } + } + } + node { + name: "save/StringJoin" + op: "StringJoin" + input: "save/Const" + input: "save/StringJoin/inputs_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "separator" + value { + s: "" + } + } + } + node { + name: "save/num_shards" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "save/ShardedFilename/shard" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "save/ShardedFilename" + op: "ShardedFilename" + input: "save/StringJoin" + input: "save/ShardedFilename/shard" + input: "save/num_shards" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/SaveV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "layer/Variable" + string_val: "layer/Variable_1" + } + } + } + } + node { + name: "save/SaveV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "" + string_val: "" + } + } + } + } + node { + name: "save/SaveV2" + op: "SaveV2" + input: "save/ShardedFilename" + input: "save/SaveV2/tensor_names" + input: "save/SaveV2/shape_and_slices" + input: "layer/Variable" + input: "layer/Variable_1" + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + } + node { + name: "save/control_dependency" + op: "Identity" + input: "save/ShardedFilename" + input: "^save/SaveV2" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@save/ShardedFilename" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/MergeV2Checkpoints/checkpoint_prefixes" + op: "Pack" + input: "save/ShardedFilename" + input: "^save/control_dependency" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "save/MergeV2Checkpoints" + op: "MergeV2Checkpoints" + input: "save/MergeV2Checkpoints/checkpoint_prefixes" + input: "save/Const" + attr { + key: "delete_old_dirs" + value { + b: true + } + } + } + node { + name: "save/Identity" + op: "Identity" + input: "save/Const" + input: "^save/control_dependency" + input: "^save/MergeV2Checkpoints" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/RestoreV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "layer/Variable" + } + } + } + } + node { + name: "save/RestoreV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2/tensor_names" + input: "save/RestoreV2/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign" + op: "Assign" + input: "layer/Variable" + input: "save/RestoreV2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_1/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "layer/Variable_1" + } + } + } + } + node { + name: "save/RestoreV2_1/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_1" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_1/tensor_names" + input: "save/RestoreV2_1/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_1" + op: "Assign" + input: "layer/Variable_1" + input: "save/RestoreV2_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/restore_shard" + op: "NoOp" + input: "^save/Assign" + input: "^save/Assign_1" + } + node { + name: "save/restore_all" + op: "NoOp" + input: "^save/restore_shard" + } + versions { + producer: 24 + } + } + saver_def { + filename_tensor_name: "save/Const:0" + save_tensor_name: "save/Identity:0" + restore_op_name: "save/restore_all" + max_to_keep: 5 + sharded: true + keep_checkpoint_every_n_hours: 10000.0 + version: V2 + } + collection_def { + key: "train_op" + value { + node_list { + value: "GradientDescent" + } + } + } + collection_def { + key: "trainable_variables" + value { + bytes_list { + value: "\n\020layer/Variable:0\022\025layer/Variable/Assign\032\025layer/Variable/read:02\rlayer/zeros:0" + value: "\n\022layer/Variable_1:0\022\027layer/Variable_1/Assign\032\027layer/Variable_1/read:02\017layer/zeros_1:0" + } + } + } + collection_def { + key: "variables" + value { + bytes_list { + value: "\n\020layer/Variable:0\022\025layer/Variable/Assign\032\025layer/Variable/read:02\rlayer/zeros:0" + value: "\n\022layer/Variable_1:0\022\027layer/Variable_1/Assign\032\027layer/Variable_1/read:02\017layer/zeros_1:0" + } + } + } + signature_def { + key: "serving_default" + value { + inputs { + key: "x" + value { + name: "Placeholder:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + outputs { + key: "y" + value { + name: "layer/add:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + method_name: "tensorflow/serving/predict" + } + } +} diff --git a/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/variables/variables.data-00000-of-00001 b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/variables/variables.data-00000-of-00001 Binary files differnew file mode 100644 index 00000000000..826b0280abf --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/variables/variables.data-00000-of-00001 diff --git a/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/variables/variables.index b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/variables/variables.index Binary files differnew file mode 100644 index 00000000000..d00fc5b06ed --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/variables/variables.index diff --git a/config-model/src/test/cfg/application/ml_serving_name_collision/services.xml b/config-model/src/test/cfg/application/ml_serving_name_collision/services.xml new file mode 100644 index 00000000000..41f44e04c99 --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving_name_collision/services.xml @@ -0,0 +1,13 @@ +<?xml version="1.0" encoding="utf-8" ?> +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<services version="1.0"> + + <container version="1.0"> + <model-evaluation/> + <nodes> + <node hostalias="node1" /> + </nodes> + + </container> + +</services> diff --git a/config-model/src/test/cfg/application/ml_serving_not_activated/models/xgboost.2.2.json b/config-model/src/test/cfg/application/ml_serving_not_activated/models/xgboost.2.2.json new file mode 100644 index 00000000000..f8949b47e52 --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving_not_activated/models/xgboost.2.2.json @@ -0,0 +1,19 @@ +[ + { "nodeid": 0, "depth": 0, "split": "f29", "split_condition": -0.1234567, "yes": 1, "no": 2, "missing": 1, "children": [ + { "nodeid": 1, "depth": 1, "split": "f56", "split_condition": -0.242398, "yes": 3, "no": 4, "missing": 3, "children": [ + { "nodeid": 3, "leaf": 1.71218 }, + { "nodeid": 4, "leaf": -1.70044 } + ]}, + { "nodeid": 2, "depth": 1, "split": "f109", "split_condition": 0.8723473, "yes": 5, "no": 6, "missing": 5, "children": [ + { "nodeid": 5, "leaf": -1.94071 }, + { "nodeid": 6, "leaf": 1.85965 } + ]} + ]}, + { "nodeid": 0, "depth": 0, "split": "f60", "split_condition": -0.482947, "yes": 1, "no": 2, "missing": 1, "children": [ + { "nodeid": 1, "depth": 1, "split": "f29", "split_condition": -4.2387498, "yes": 3, "no": 4, "missing": 3, "children": [ + { "nodeid": 3, "leaf": 0.784718 }, + { "nodeid": 4, "leaf": -0.96853 } + ]}, + { "nodeid": 2, "leaf": -6.23624 } + ]} +]
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/ml_serving_not_activated/services.xml b/config-model/src/test/cfg/application/ml_serving_not_activated/services.xml new file mode 100644 index 00000000000..9d8b7a81201 --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving_not_activated/services.xml @@ -0,0 +1,13 @@ +<?xml version="1.0" encoding="utf-8" ?> +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<services version="1.0"> + + <container version="1.0"> + <!-- No <model-evaluation/> tag --> + <nodes> + <node hostalias="node1" /> + </nodes> + + </container> + +</services> diff --git a/config-model/src/test/cfg/search/compare/complex/search/cluster.music/tlds/tld.0/fdispatchrc.MODEL.cfg b/config-model/src/test/cfg/search/compare/complex/search/cluster.music/tlds/tld.0/fdispatchrc.MODEL.cfg index e9368ca2662..34cf36e450e 100644 --- a/config-model/src/test/cfg/search/compare/complex/search/cluster.music/tlds/tld.0/fdispatchrc.MODEL.cfg +++ b/config-model/src/test/cfg/search/compare/complex/search/cluster.music/tlds/tld.0/fdispatchrc.MODEL.cfg @@ -16,7 +16,6 @@ partition 0 ptport 19115 transport "" transportnodelay true -transportdirectwrite false packetcompresslimit 1024 packetcompresslevel 3 packetcompresstype LZ4 diff --git a/config-model/src/test/cfg/search/compare/complex/search/cluster.music/tlds/tld.1/fdispatchrc.MODEL.cfg b/config-model/src/test/cfg/search/compare/complex/search/cluster.music/tlds/tld.1/fdispatchrc.MODEL.cfg index aa48d5fec79..d11651139fa 100644 --- a/config-model/src/test/cfg/search/compare/complex/search/cluster.music/tlds/tld.1/fdispatchrc.MODEL.cfg +++ b/config-model/src/test/cfg/search/compare/complex/search/cluster.music/tlds/tld.1/fdispatchrc.MODEL.cfg @@ -16,7 +16,6 @@ partition 0 ptport 19118 transport "" transportnodelay true -transportdirectwrite false packetcompresslimit 1024 packetcompresslevel 3 packetcompresstype LZ4 diff --git a/config-model/src/test/cfg/search/compare/complex/search/cluster.rt/tlds/tld.0/fdispatchrc.MODEL.cfg b/config-model/src/test/cfg/search/compare/complex/search/cluster.rt/tlds/tld.0/fdispatchrc.MODEL.cfg index ac173575923..2d3ea040969 100644 --- a/config-model/src/test/cfg/search/compare/complex/search/cluster.rt/tlds/tld.0/fdispatchrc.MODEL.cfg +++ b/config-model/src/test/cfg/search/compare/complex/search/cluster.rt/tlds/tld.0/fdispatchrc.MODEL.cfg @@ -16,7 +16,6 @@ partition 0 ptport 19149 transport "" transportnodelay true -transportdirectwrite false packetcompresslimit 1024 packetcompresslevel 3 packetcompresstype LZ4 diff --git a/config-model/src/test/cfg/search/compare/optionals/search/cluster.music/tlds/tld.0/fdispatchrc.MODEL.cfg b/config-model/src/test/cfg/search/compare/optionals/search/cluster.music/tlds/tld.0/fdispatchrc.MODEL.cfg index aa48d5fec79..d11651139fa 100644 --- a/config-model/src/test/cfg/search/compare/optionals/search/cluster.music/tlds/tld.0/fdispatchrc.MODEL.cfg +++ b/config-model/src/test/cfg/search/compare/optionals/search/cluster.music/tlds/tld.0/fdispatchrc.MODEL.cfg @@ -16,7 +16,6 @@ partition 0 ptport 19118 transport "" transportnodelay true -transportdirectwrite false packetcompresslimit 1024 packetcompresslevel 3 packetcompresstype LZ4 diff --git a/config-model/src/test/cfg/search/compare/simple/search/cluster.music/tlds/tld.0/fdispatchrc.MODEL.cfg b/config-model/src/test/cfg/search/compare/simple/search/cluster.music/tlds/tld.0/fdispatchrc.MODEL.cfg index ce9e77ffdbc..9ddce0a52a9 100644 --- a/config-model/src/test/cfg/search/compare/simple/search/cluster.music/tlds/tld.0/fdispatchrc.MODEL.cfg +++ b/config-model/src/test/cfg/search/compare/simple/search/cluster.music/tlds/tld.0/fdispatchrc.MODEL.cfg @@ -16,7 +16,6 @@ partition 0 ptport 19111 transport "" transportnodelay true -transportdirectwrite false packetcompresslimit 1024 packetcompresslevel 3 packetcompresstype LZ4 diff --git a/config-model/src/test/cfg/search/compare/twoFeedTargetClusters/search/cluster.music1/tlds/tld.0/fdispatchrc.MODEL.cfg b/config-model/src/test/cfg/search/compare/twoFeedTargetClusters/search/cluster.music1/tlds/tld.0/fdispatchrc.MODEL.cfg index d4135d10175..d5ab3679c99 100644 --- a/config-model/src/test/cfg/search/compare/twoFeedTargetClusters/search/cluster.music1/tlds/tld.0/fdispatchrc.MODEL.cfg +++ b/config-model/src/test/cfg/search/compare/twoFeedTargetClusters/search/cluster.music1/tlds/tld.0/fdispatchrc.MODEL.cfg @@ -16,7 +16,6 @@ partition 0 ptport 19108 transport "" transportnodelay true -transportdirectwrite false packetcompresslimit 1024 packetcompresslevel 3 packetcompresstype LZ4 diff --git a/config-model/src/test/cfg/search/compare/twoFeedTargetClusters/search/cluster.music2/tlds/tld.0/fdispatchrc.MODEL.cfg b/config-model/src/test/cfg/search/compare/twoFeedTargetClusters/search/cluster.music2/tlds/tld.0/fdispatchrc.MODEL.cfg index ee8b6cdd963..487310efd77 100644 --- a/config-model/src/test/cfg/search/compare/twoFeedTargetClusters/search/cluster.music2/tlds/tld.0/fdispatchrc.MODEL.cfg +++ b/config-model/src/test/cfg/search/compare/twoFeedTargetClusters/search/cluster.music2/tlds/tld.0/fdispatchrc.MODEL.cfg @@ -16,7 +16,6 @@ partition 0 ptport 19119 transport "" transportnodelay true -transportdirectwrite false packetcompresslimit 1024 packetcompresslevel 3 packetcompresstype LZ4 diff --git a/config-model/src/test/integration/onnx/services.xml b/config-model/src/test/integration/onnx/services.xml new file mode 100644 index 00000000000..f623b2464fc --- /dev/null +++ b/config-model/src/test/integration/onnx/services.xml @@ -0,0 +1,5 @@ +<services> + <container version="1.0"> + + </container> +</services>
\ No newline at end of file diff --git a/config-model/src/test/integration/tensorflow/services.xml b/config-model/src/test/integration/tensorflow/services.xml new file mode 100644 index 00000000000..f623b2464fc --- /dev/null +++ b/config-model/src/test/integration/tensorflow/services.xml @@ -0,0 +1,5 @@ +<services> + <container version="1.0"> + + </container> +</services>
\ No newline at end of file diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java deleted file mode 100644 index c5fb4f575cf..00000000000 --- a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java +++ /dev/null @@ -1,92 +0,0 @@ -package com.yahoo.config.model; - -import ai.vespa.models.evaluation.Model; -import ai.vespa.models.evaluation.ModelsEvaluator; -import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.io.IOUtils; -import com.yahoo.path.Path; -import com.yahoo.vespa.config.search.RankProfilesConfig; -import com.yahoo.vespa.model.VespaModel; -import com.yahoo.vespa.model.container.ContainerCluster; -import org.junit.After; -import org.junit.Test; -import org.xml.sax.SAXException; - -import java.io.IOException; -import java.util.Set; -import java.util.stream.Collectors; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - -/** - * @author bratseth - */ -public class ModelEvaluationTest { - - private static final Path appDir = Path.fromString("src/test/cfg/application/ml_serving"); - - @After - public void removeGeneratedModelFiles() { - IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); - } - - @Test - public void testMl_ServingApplication() throws SAXException, IOException { - ApplicationPackageTester tester = ApplicationPackageTester.create(appDir.toString()); - VespaModel model = new VespaModel(tester.app()); - assertHasMlModels(model); - - // At this point the expression is stored - copy application to another location which do not have a models dir - Path storedAppDir = appDir.append("copy"); - try { - storedAppDir.toFile().mkdirs(); - IOUtils.copy(appDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString()); - IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), - storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); - ApplicationPackageTester storedTester = ApplicationPackageTester.create(storedAppDir.toString()); - VespaModel storedModel = new VespaModel(storedTester.app()); - assertHasMlModels(storedModel); - } - finally { - IOUtils.recursiveDeleteDir(storedAppDir.toFile()); - } - } - - private void assertHasMlModels(VespaModel model) { - ContainerCluster cluster = model.getContainerClusters().get("container"); - RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); - cluster.getConfig(b); - RankProfilesConfig config = new RankProfilesConfig(b); - assertEquals(4, config.rankprofile().size()); - Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet()); - assertTrue(modelNames.contains("xgboost_2_2")); - assertTrue(modelNames.contains("mnist_softmax")); - assertTrue(modelNames.contains("mnist_softmax_saved")); - - ModelsEvaluator evaluator = new ModelsEvaluator(config); - - assertEquals(4, evaluator.models().size()); - Model xgboost = evaluator.models().get("xgboost_2_2"); - assertNotNull(xgboost); - assertNotNull(xgboost.evaluatorOf()); - assertNotNull(xgboost.evaluatorOf("xgboost_2_2")); - - Model onnx = evaluator.models().get("mnist_softmax"); - assertNotNull(onnx); - assertNotNull(onnx.evaluatorOf()); - assertNotNull(onnx.evaluatorOf("default")); - assertNotNull(onnx.evaluatorOf("default", "add")); - assertNotNull(onnx.evaluatorOf("default.add")); - assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default.add")); - assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default", "add")); - - Model tensorflow = evaluator.models().get("mnist_softmax_saved"); - assertNotNull(tensorflow); - assertNotNull(tensorflow.evaluatorOf()); - assertNotNull(tensorflow.evaluatorOf("serving_default")); - assertNotNull(tensorflow.evaluatorOf("serving_default", "y")); - } - -} diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelNameCollisionTest.java b/config-model/src/test/java/com/yahoo/config/model/ModelNameCollisionTest.java new file mode 100644 index 00000000000..08f18331d1c --- /dev/null +++ b/config-model/src/test/java/com/yahoo/config/model/ModelNameCollisionTest.java @@ -0,0 +1,43 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.config.model; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.vespa.model.VespaModel; +import org.junit.After; +import org.junit.Test; +import org.xml.sax.SAXException; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class ModelNameCollisionTest { + + private static final Path appDir = Path.fromString("src/test/cfg/application/ml_serving_name_collision"); + + @After + public void removeGeneratedModelFiles() { + IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + } + + @Test + public void testMl_ServingApplication() throws SAXException, IOException { + ApplicationPackageTester tester = ApplicationPackageTester.create(appDir.toString()); + try { + new VespaModel(tester.app()); + } + catch (IllegalArgumentException e) { + assertEquals("The models in " + + appDir + "/models/parent/mnist_softmax.onnx and " + + appDir + "/models/parent/mnist_softmax" + + " both resolve to the model name 'parent_mnist_softmax'", + e.getMessage()); + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java index aa01070d296..056fc27f067 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java @@ -1,8 +1,4 @@ -/* - * // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - * - * - */ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; import org.junit.Test; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java index e1ddd0c02ca..b13ffabda77 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java @@ -7,12 +7,14 @@ import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.yolean.Exceptions; import org.junit.Test; import java.util.Optional; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; /** * @author bratseth diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java new file mode 100644 index 00000000000..df9a40d29e2 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java @@ -0,0 +1,197 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition; + +import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.yolean.Exceptions; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** + * @author bratseth + */ +public class RankingExpressionLoopDetectionTestCase { + + @Test + public void testSelfLoop() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString( + "search test {\n" + + " document test { \n" + + " field a type string { \n" + + " indexing: index \n" + + " }\n" + + " }\n" + + " \n" + + " rank-profile test {\n" + + " first-phase {\n" + + " expression: foo\n" + + " }\n" + + " macro foo() {\n" + + " expression: foo\n" + + " }\n" + + " }\n" + + "\n" + + "}\n"); + try { + builder.build(); + fail("Excepted exception"); + } + catch (IllegalArgumentException e) { + assertEquals("In search definition 'test', rank profile 'test': The first-phase expression is invalid: Invocation loop: foo -> foo", + Exceptions.toMessageString(e)); + } + } + + @Test + public void testNestedLoop() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString( + "search test {\n" + + " document test { \n" + + " field a type string { \n" + + " indexing: index \n" + + " }\n" + + " }\n" + + " \n" + + " rank-profile test {\n" + + " first-phase {\n" + + " expression: foo\n" + + " }\n" + + " macro foo() {\n" + + " expression: arg(5)\n" + + " }\n" + + " macro arg(a1) {\n" + + " expression: foo + a1*2\n" + + " }\n" + + " }\n" + + "\n" + + "}\n"); + try { + builder.build(); + fail("Excepted exception"); + } + catch (IllegalArgumentException e) { + assertEquals("In search definition 'test', rank profile 'test': The first-phase expression is invalid: Invocation loop: foo -> arg(5) -> foo", + Exceptions.toMessageString(e)); + } + } + + @Test + public void testSelfArgumentLoop() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString( + "search test {\n" + + " document test { \n" + + " field a type string { \n" + + " indexing: index \n" + + " }\n" + + " }\n" + + " \n" + + " rank-profile test {\n" + + " first-phase {\n" + + " expression: foo\n" + + " }\n" + + " macro foo() {\n" + + " expression: arg(foo)\n" + + " }\n" + + " macro arg(a1) {\n" + + " expression: a1*2\n" + + " }\n" + + " }\n" + + "\n" + + "}\n"); + try { + builder.build(); + fail("Excepted exception"); + } + catch (IllegalArgumentException e) { + assertEquals("In search definition 'test', rank profile 'test': The first-phase expression is invalid: Invocation loop: foo -> arg(foo) -> foo", + Exceptions.toMessageString(e)); + } + } + + @Test + public void testNoLoopWithSameLocalArgument() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString( + "search test {\n" + + " document test { \n" + + " field a type string { \n" + + " indexing: index \n" + + " }\n" + + " }\n" + + " \n" + + " rank-profile test {\n" + + " first-phase {\n" + + " expression: foo(3)\n" + + " }\n" + + " macro foo(a1) {\n" + + " expression: bar(3)\n" + + " }\n" + + " macro bar(a1) {\n" + + " expression: a1*2\n" + + " }\n" + + " }\n" + + "\n" + + "}\n"); + builder.build(); + } + + @Test + public void testNoLoopWithMultipleInvocations() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString( + "search test {\n" + + " document test { \n" + + " field a type string { \n" + + " indexing: index \n" + + " }\n" + + " }\n" + + " \n" + + " rank-profile test {\n" + + " first-phase {\n" + + " expression: foo(3)\n" + + " }\n" + + " macro foo(a1) {\n" + + " expression: bar(3) + bar(a1)\n" + + " }\n" + + " macro bar(a1) {\n" + + " expression: a1*2\n" + + " }\n" + + " }\n" + + "\n" + + "}\n"); + builder.build(); + } + + @Test + public void testNoLoopWithBoundIdentifiers() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString( + "search test {\n" + + " document test { \n" + + " }\n" + + " rank-profile test {\n" + + " first-phase {\n" + + " expression: foo(bar(2))\n" + + " }\n" + + " macro foo(x) {\n" + + " expression: x * x\n" + + " }\n" + + " macro bar(x) {\n" + + " expression: x + x\n" + + " }\n" + + " }\n" + + "}\n"); + builder.build(); + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 414a77e9164..b046d60f948 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -1,27 +1,22 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.searchdefinition.processing; import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.ml.ImportedModelTester; import com.yahoo.yolean.Exceptions; import org.junit.After; import org.junit.Test; import java.io.IOException; -import java.io.UncheckedIOException; import java.util.Optional; import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage; -import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; @@ -41,14 +36,36 @@ public class RankingExpressionWithOnnxTestCase { } @Test + public void testGlobalOnnxModel() throws IOException { + ImportedModelTester tester = new ImportedModelTester(name, applicationDir); + VespaModel model = tester.createVespaModel(); + tester.assertLargeConstant(name + "_Variable_1", model, Optional.of(10L)); + tester.assertLargeConstant(name + "_Variable", model, Optional.of(7840L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedAppDir = applicationDir.append("copy"); + try { + storedAppDir.toFile().mkdirs(); + IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString()); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir); + VespaModel storedModel = storedTester.createVespaModel(); + tester.assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L)); + tester.assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L)); + } + finally { + IOUtils.recursiveDeleteDir(storedAppDir.toFile()); + } + } + + @Test public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "onnx('mnist_softmax.onnx')", "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant(name + "_Variable", search, Optional.of(7840L)); } @Test @@ -68,8 +85,6 @@ public class RankingExpressionWithOnnxTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant(name + "_Variable", search, Optional.of(7840L)); } @Test @@ -82,8 +97,6 @@ public class RankingExpressionWithOnnxTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); } @@ -104,8 +117,6 @@ public class RankingExpressionWithOnnxTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); } @@ -114,8 +125,6 @@ public class RankingExpressionWithOnnxTestCase { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "5 + sum(onnx('mnist_softmax.onnx'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); } @Test @@ -181,9 +190,6 @@ public class RankingExpressionWithOnnxTestCase { "onnx('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); - // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); try { @@ -200,8 +206,6 @@ public class RankingExpressionWithOnnxTestCase { searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); // Verify that the constants exists, but don't verify the content as we are not // simulating file distribution in this test - assertLargeConstant( name + "_Variable_1", searchFromStored, Optional.empty()); - assertLargeConstant( name + "_Variable", searchFromStored, Optional.empty()); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -232,7 +236,6 @@ public class RankingExpressionWithOnnxTestCase { assertNull("Constant overridden by macro is not added", search.search().rankingConstants().get( name + "_Variable")); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -245,38 +248,12 @@ public class RankingExpressionWithOnnxTestCase { searchFromStored.compileRankProfile("my_profile", applicationDir.append("models")); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); assertNull("Constant overridden by macro is not added", - searchFromStored.search().rankingConstants().get( name + "_Variable")); - assertLargeConstant( name + "_Variable_1", searchFromStored, Optional.of(10L)); + searchFromStored.search().rankingConstants().get( name + "_Variable")); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); } } - /** - * Verifies that the constant with the given name exists, and - only if an expected size is given - - * that the content of the constant is available and has the expected size. - */ - private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) { - try { - Path constantApplicationPackagePath = Path.fromString("models.generated/my_profile.mnist_softmax.onnx/constants").append(name + ".tbf"); - RankingConstant rankingConstant = search.search().rankingConstants().get(name); - assertEquals(name, rankingConstant.getName()); - assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString())); - - if (expectedSize.isPresent()) { - Path constantPath = applicationDir.append(constantApplicationPackagePath); - assertTrue("Constant file '" + constantPath + "' has been written", - constantPath.toFile().exists()); - Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), - GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile()))); - assertEquals(expectedSize.get().longValue(), deserializedConstant.size()); - } - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) { return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder", new StoringApplicationPackage(applicationDir)); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 450c66e04ef..14632a568ea 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -15,27 +15,22 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.ml.ImportedModelTester; import com.yahoo.yolean.Exceptions; import org.junit.After; import org.junit.Test; -import java.io.BufferedInputStream; import java.io.File; -import java.io.FileInputStream; -import java.io.FileNotFoundException; import java.io.FileReader; import java.io.IOException; -import java.io.InputStream; -import java.io.Reader; import java.io.UncheckedIOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; -import java.util.Iterator; import java.util.List; import java.util.Optional; -import java.util.stream.Collectors; +import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.*; /** @@ -56,12 +51,34 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test + public void testGlobalTensorFlowModel() throws IOException { + ImportedModelTester tester = new ImportedModelTester(name, applicationDir); + VespaModel model = tester.createVespaModel(); + assertLargeConstant(name + "_layer_Variable_1_read", model, Optional.of(10L)); + assertLargeConstant(name + "_layer_Variable_read", model, Optional.of(7840L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedAppDir = applicationDir.append("copy"); + try { + storedAppDir.toFile().mkdirs(); + IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString()); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir); + VespaModel storedModel = storedTester.createVespaModel(); + tester.assertLargeConstant(name + "_layer_Variable_1_read", storedModel, Optional.of(10L)); + tester.assertLargeConstant(name + "_layer_Variable_read", storedModel, Optional.of(7840L)); + } + finally { + IOUtils.recursiveDeleteDir(storedAppDir.toFile()); + } + } + + @Test public void testTensorFlowReference() { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -71,8 +88,6 @@ public class RankingExpressionWithTensorFlowTestCase { "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -91,8 +106,6 @@ public class RankingExpressionWithTensorFlowTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -105,8 +118,6 @@ public class RankingExpressionWithTensorFlowTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -125,8 +136,6 @@ public class RankingExpressionWithTensorFlowTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -134,8 +143,6 @@ public class RankingExpressionWithTensorFlowTestCase { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "5 + sum(tensorflow('mnist_softmax/saved'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -233,9 +240,6 @@ public class RankingExpressionWithTensorFlowTestCase { "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); - // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); try { @@ -250,10 +254,6 @@ public class RankingExpressionWithTensorFlowTestCase { "Placeholder", storedApplication); searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); - // Verify that the constants exists, but don't verify the content as we are not - // simulating file distribution in this test - assertLargeConstant(name + "_layer_Variable_1_read", searchFromStored, Optional.empty()); - assertLargeConstant(name + "_layer_Variable_read", searchFromStored, Optional.empty()); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -287,7 +287,6 @@ public class RankingExpressionWithTensorFlowTestCase { assertNull("Constant overridden by macro is not added", search.search().rankingConstants().get("mnist_softmax_saved_layer_Variable_read")); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -303,7 +302,6 @@ public class RankingExpressionWithTensorFlowTestCase { searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); assertNull("Constant overridden by macro is not added", searchFromStored.search().rankingConstants().get("mnist_softmax_saved_layer_Variable_read")); - assertLargeConstant(name + "_layer_Variable_1_read", searchFromStored, Optional.of(10L)); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -316,8 +314,6 @@ public class RankingExpressionWithTensorFlowTestCase { RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(expression, "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -401,11 +397,11 @@ public class RankingExpressionWithTensorFlowTestCase { * Verifies that the constant with the given name exists, and - only if an expected size is given - * that the content of the constant is available and has the expected size. */ - private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) { + private void assertLargeConstant(String constantName, VespaModel model, Optional<Long> expectedSize) { try { - Path constantApplicationPackagePath = Path.fromString("models.generated/my_profile.mnist_softmax_saved/constants").append(name + ".tbf"); - RankingConstant rankingConstant = search.search().rankingConstants().get(name); - assertEquals(name, rankingConstant.getName()); + Path constantApplicationPackagePath = Path.fromString("models.generated/" + name + "/constants").append(constantName + ".tbf"); + RankingConstant rankingConstant = model.rankingConstants().get(constantName); + assertEquals(constantName, rankingConstant.getName()); assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString())); if (expectedSize.isPresent()) { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/admin/DedicatedAdminV4Test.java b/config-model/src/test/java/com/yahoo/vespa/model/admin/DedicatedAdminV4Test.java index 7b586354394..27839765930 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/admin/DedicatedAdminV4Test.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/admin/DedicatedAdminV4Test.java @@ -5,10 +5,15 @@ import com.yahoo.cloud.config.LogforwarderConfig; import com.yahoo.cloud.config.SentinelConfig; import com.yahoo.config.model.NullConfigModelRegistry; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.deploy.DeployProperties; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.provision.Hosts; import com.yahoo.config.model.provision.InMemoryProvisioner; import com.yahoo.config.model.test.MockApplicationPackage; +import com.yahoo.config.provision.Environment; +import com.yahoo.config.provision.RegionName; +import com.yahoo.config.provision.SystemName; +import com.yahoo.config.provision.Zone; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.admin.monitoring.Metric; import com.yahoo.vespa.model.admin.monitoring.MetricsConsumer; @@ -73,7 +78,8 @@ public class DedicatedAdminV4Test { assertHostContainsServices(model, "hosts/myhost0", "slobrok", "logd"); assertHostContainsServices(model, "hosts/myhost1", "slobrok", "logd"); - assertHostContainsServices(model, "hosts/myhost2", "logserver", "logd"); + // Note: A container is always added on logserver host + assertHostContainsServices(model, "hosts/myhost2", "logserver", "logd", "container"); Monitoring monitoring = model.getAdmin().getMonitoring(); assertEquals("vespa.routing", monitoring.getClustername()); @@ -155,7 +161,8 @@ public class DedicatedAdminV4Test { assertHostContainsServices(model, "hosts/myhost0", "logd", "logforwarder", "slobrok"); assertHostContainsServices(model, "hosts/myhost1", "logd", "logforwarder", "slobrok"); - assertHostContainsServices(model, "hosts/myhost2", "logd", "logforwarder", "logserver"); + // Note: A container is always added on logserver host + assertHostContainsServices(model, "hosts/myhost2", "logd", "logforwarder", "logserver", "container"); Set<String> configIds = model.getConfigIds(); // 1 logforwarder on each host @@ -183,6 +190,26 @@ public class DedicatedAdminV4Test { } } + @Test + public void testDedicatedLogserverInHostedVespa() throws IOException, SAXException { + String services = "<services>" + + " <admin version='4.0'>" + + " <logservers>" + + " <nodes count='1' dedicated='true'/>" + + " </logservers>" + + " </admin>" + + "</services>"; + + VespaModel model = createModel(hosts, services, new DeployState.Builder() + .zone(new Zone(SystemName.cd, Environment.dev, RegionName.defaultName())) + .properties(new DeployProperties.Builder() + .hostedVespa(true) + .build())); + assertEquals(1, model.getHosts().size()); + // Should create a container on the same node as logserver + assertHostContainsServices(model, "hosts/myhost0", "slobrok", "logd", "logserver", "container"); + } + private Set<String> serviceNames(VespaModel model, String hostname) { SentinelConfig config = model.getConfig(SentinelConfig.class, hostname); return config.service().stream().map(SentinelConfig.Service::name).collect(Collectors.toSet()); @@ -197,14 +224,18 @@ public class DedicatedAdminV4Test { } private VespaModel createModel(String hosts, String services) throws IOException, SAXException { + return createModel(hosts, services, new DeployState.Builder()); + } + + private VespaModel createModel(String hosts, String services, DeployState.Builder deployStateBuilder) throws IOException, SAXException { ApplicationPackage app = new MockApplicationPackage.Builder() .withHosts(hosts) .withServices(services) .build(); - return new VespaModel(new NullConfigModelRegistry(), - new DeployState.Builder().applicationPackage(app).modelHostProvisioner( - new InMemoryProvisioner(Hosts.readFrom(app.getHosts()), true)) - .build()); + return new VespaModel(new NullConfigModelRegistry(), deployStateBuilder + .applicationPackage(app) + .modelHostProvisioner(new InMemoryProvisioner(Hosts.readFrom(app.getHosts()), true)) + .build()); } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ContainerClusterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ContainerClusterTest.java index 850fd91e151..756a0c53485 100755 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/ContainerClusterTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ContainerClusterTest.java @@ -81,7 +81,7 @@ public class ContainerClusterTest { .zone(new Zone(SystemName.cd, Environment.test, RegionName.from("some-region"))) .build(); MockRoot root = new MockRoot("foo", state); - ContainerCluster cluster = new ContainerCluster(root, "container0", "container1", RankProfileList.empty); + ContainerCluster cluster = new ContainerCluster(root, "container0", "container1"); ConfigserverConfig.Builder builder = new ConfigserverConfig.Builder(); cluster.getConfig(builder); ConfigserverConfig config = new ConfigserverConfig(builder); @@ -112,8 +112,8 @@ public class ContainerClusterTest { MockRoot root = new MockRoot("foo", state); ContainerCluster cluster = extraComponents.isPresent() - ? new ContainerCluster(root, "container0", "container1", extraComponents.get(), RankProfileList.empty) - : new ContainerCluster(root, "container0", "container1", RankProfileList.empty); + ? new ContainerCluster(root, "container0", "container1", extraComponents.get()) + : new ContainerCluster(root, "container0", "container1"); if (isCombinedCluster) cluster.setHostClusterId("test-content-cluster"); cluster.setMemoryPercentage(memoryPercentage); @@ -258,7 +258,7 @@ public class ContainerClusterTest { public void requireThatRoutingProviderIsDisabledForNonHosted() { DeployState state = new DeployState.Builder().properties(new DeployProperties.Builder().hostedVespa(false).build()).build(); MockRoot root = new MockRoot("foo", state); - ContainerCluster cluster = new ContainerCluster(root, "container0", "container1", RankProfileList.empty); + ContainerCluster cluster = new ContainerCluster(root, "container0", "container1"); RoutingProviderConfig.Builder builder = new RoutingProviderConfig.Builder(); cluster.getConfig(builder); RoutingProviderConfig config = new RoutingProviderConfig(builder); @@ -282,7 +282,7 @@ public class ContainerClusterTest { } private static ContainerCluster newContainerCluster() { - ContainerCluster cluster = new ContainerCluster(null, "subId", "name", RankProfileList.empty); + ContainerCluster cluster = new ContainerCluster(null, "subId", "name"); addContainer(cluster, "c1", "host-c1"); addContainer(cluster, "c2", "host-c2"); return cluster; diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/jersey/xml/MultipleRestApisTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/jersey/xml/MultipleRestApisTest.java deleted file mode 100644 index d36ab74c6f1..00000000000 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/jersey/xml/MultipleRestApisTest.java +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.model.container.jersey.xml; - -import com.yahoo.component.ComponentId; -import com.yahoo.config.model.builder.xml.test.DomBuilderTest; -import com.yahoo.container.ComponentsConfig; -import com.yahoo.container.di.config.JerseyBundlesConfig; -import com.yahoo.container.jdisc.JdiscBindingsConfig; -import com.yahoo.vespa.model.container.jersey.JerseyHandler; -import com.yahoo.vespa.model.container.jersey.RestApi; -import com.yahoo.vespa.model.container.jersey.RestApiContext; -import com.yahoo.vespa.model.container.xml.ContainerModelBuilderTestBase; -import org.junit.Before; -import org.junit.Test; - -import java.util.Map; - -import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.CoreMatchers.hasItems; -import static org.hamcrest.CoreMatchers.not; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.nullValue; -import static org.hamcrest.core.Is.is; -import static org.junit.Assert.assertTrue; - -/** - * @author bjorncs - */ -public class MultipleRestApisTest extends ContainerModelBuilderTestBase { - - private static final String CLUSTER_ID = "container"; - private static final String PATH_1 = "rest_1"; - private static final String PATH_2 = "rest_2"; - private static final String HTTP_BINDING_1 = "http://*/" + PATH_1 + "/*"; - private static final String HTTPS_BINDING_1 = "https://*/" + PATH_1 + "/*"; - private static final String HTTP_BINDING_2 = "http://*/" + PATH_2 + "/*"; - private static final String HTTPS_BINDING_2 = "https://*/" + PATH_2 + "/*"; - private static final String HANDLER_ID_1 = JerseyHandler.CLASS + "-" + PATH_1; - private static final String HANDLER_ID_2 = JerseyHandler.CLASS + "-" + PATH_2; - private static final String REST_API_CONTEXT_ID_1 = RestApiContext.CONTAINER_CLASS + "-" + PATH_1; - private static final String REST_API_CONTEXT_ID_2 = RestApiContext.CONTAINER_CLASS + "-" + PATH_2; - private static final String REST_API_XML = - "<container version=\"1.0\" id=\"" + CLUSTER_ID + "\">\n" + - " <rest-api path=\"" + PATH_1 + "\">\n" + - " <components bundle=\"bundle1\" />\n" + - " </rest-api>\n" + - " <rest-api path=\"" + PATH_2 + "\">\n" + - " <components bundle=\"bundle2\" />\n" + - " </rest-api>\n" + - "</container>"; - - - private JerseyHandler handler1; - private JerseyHandler handler2; - private Map<ComponentId, RestApi> restApis; - - @Before - public void setup() throws Exception { - createModel(root, DomBuilderTest.parse(REST_API_XML)); - handler1 = (JerseyHandler)getContainerComponentNested(CLUSTER_ID, HANDLER_ID_1); - handler2 = (JerseyHandler)getContainerComponentNested(CLUSTER_ID, HANDLER_ID_2); - restApis = getContainerCluster(CLUSTER_ID).getRestApiMap(); - } - - @Test - public void cluster_has_all_rest_apis() { - assertThat(restApis.size(), is(2)); - } - - @Test - public void rest_apis_have_path_as_component_id() { - assertTrue(restApis.get(ComponentId.fromString(PATH_1)) instanceof RestApi); - assertTrue(restApis.get(ComponentId.fromString(PATH_2)) instanceof RestApi); - } - - @Test - public void jersey_handler_has_correct_bindings() { - assertThat(handler1, not(nullValue())); - assertThat(handler1.getServerBindings(), hasItems(HTTP_BINDING_1, HTTPS_BINDING_1)); - - assertThat(handler2, not(nullValue())); - assertThat(handler2.getServerBindings(), hasItems(HTTP_BINDING_2, HTTPS_BINDING_2)); - } - - @Test - public void jersey_bindings_are_included_in_config() { - JdiscBindingsConfig config = root.getConfig(JdiscBindingsConfig.class, CLUSTER_ID); - assertThat(config.handlers(HANDLER_ID_1).serverBindings(), hasItems(HTTP_BINDING_1, HTTPS_BINDING_1)); - assertThat(config.handlers(HANDLER_ID_2).serverBindings(), hasItems(HTTP_BINDING_2, HTTPS_BINDING_2)); - } - - - @Test - public void jersey_handler_for_each_rest_api_is_included_in_components_config() { - ComponentsConfig config = root.getConfig(ComponentsConfig.class, CLUSTER_ID); - assertThat(config.toString(), containsString(".id \"" + HANDLER_ID_1 + "\"")); - assertThat(config.toString(), containsString(".id \"" + HANDLER_ID_2 + "\"")); - } - - @Test - public void jersey_bundles_component_for_each_rest_api_is_included_in_components_config() { - - ComponentsConfig config = root.getConfig(ComponentsConfig.class, CLUSTER_ID); - assertThat(config.toString(), containsString(".id \"" + REST_API_CONTEXT_ID_1 + "\"")); - assertThat(config.toString(), containsString(".id \"" + REST_API_CONTEXT_ID_2 + "\"")); - } - - @Test - public void each_rest_api_has_correct_bundle() { - RestApiContext restApiContext1 = restApis.get(ComponentId.fromString(PATH_1)).getContext(); - RestApiContext restApiContext2 = restApis.get(ComponentId.fromString(PATH_2)).getContext(); - - JerseyBundlesConfig bundlesConfig1 = root.getConfig(JerseyBundlesConfig.class, restApiContext1.getConfigId()); - assertThat(bundlesConfig1.toString(), containsString("bundle1")); - assertThat(bundlesConfig1.toString(), not(containsString("bundle2"))); - - JerseyBundlesConfig bundlesConfig2 = root.getConfig(JerseyBundlesConfig.class, restApiContext2.getConfigId()); - assertThat(bundlesConfig2.toString(), containsString("bundle2")); - assertThat(bundlesConfig2.toString(), not(containsString("bundle1"))); - } -} diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/jersey/xml/RestApiTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/jersey/xml/RestApiTest.java index 4b28dfa0b9d..503b38b79b4 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/jersey/xml/RestApiTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/jersey/xml/RestApiTest.java @@ -2,28 +2,24 @@ package com.yahoo.vespa.model.container.jersey.xml; import com.yahoo.component.ComponentId; -import com.yahoo.config.model.builder.xml.test.DomBuilderTest; +import com.yahoo.config.model.test.TestUtil; import com.yahoo.container.ComponentsConfig; -import com.yahoo.container.config.jersey.JerseyInitConfig; import com.yahoo.container.di.config.JerseyBundlesConfig; -import com.yahoo.container.di.config.JerseyInjectionConfig; -import com.yahoo.container.jdisc.JdiscBindingsConfig; +import com.yahoo.jdisc.http.ServletPathsConfig; import com.yahoo.vespa.model.container.component.Component; -import com.yahoo.vespa.model.container.component.Handler; -import com.yahoo.vespa.model.container.jersey.JerseyHandler; +import com.yahoo.vespa.model.container.jersey.Jersey2Servlet; import com.yahoo.vespa.model.container.jersey.RestApi; import com.yahoo.vespa.model.container.jersey.RestApiContext; import com.yahoo.vespa.model.container.xml.ContainerModelBuilderTestBase; -import org.junit.Ignore; +import org.junit.Before; import org.junit.Test; +import org.w3c.dom.Element; import java.util.HashSet; import java.util.Set; import java.util.stream.Collectors; import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.CoreMatchers.hasItem; -import static org.hamcrest.CoreMatchers.hasItems; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.nullValue; @@ -32,103 +28,86 @@ import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; /** + * @author gjoranv * @author bjorncs */ public class RestApiTest extends ContainerModelBuilderTestBase { - private static final String Path = "rest/api"; - private static final String HttpBinding = "http://*/" + Path + "/*"; - private static final String HttpsBinding = "https://*/" + Path + "/*"; - private static final String HandlerId = JerseyHandler.CLASS + "-" + RestApi.idFromPath(Path); - private static final String RestApiContextId = RestApiContext.CONTAINER_CLASS + "-" + RestApi.idFromPath(Path); - private static final String InjectedComponentId = "injectedHandler"; - - private static final String ClusterId = "container"; - - private static final String restApiXml = - "<container version=\"1.0\" id=\"" + ClusterId + "\" jetty=\"true\">\n" + - " <rest-api path=\"" + Path + "\">\n" + - " <components bundle=\"my-jersey-bundle:1.0\">\n" + - " <package>com.yahoo.foo</package>\n" + - " </components>\n" + - " </rest-api>\n" + - " <handler id=\"" + InjectedComponentId + "\" />\n" + - "</container>"; + private static final String PATH = "rest/api"; + private static final String REST_API_CONTEXT_ID = RestApiContext.CONTAINER_CLASS + "-" + RestApi.idFromPath(PATH); + private static final String INJECTED_COMPONENT_ID = "injectedHandler"; + private static final String CLUSTER_ID = "container"; + + private static final Element restApiXml = TestUtil.parse( + "<container version=\"1.0\" id=\"" + CLUSTER_ID + "\">", + " <rest-api path=\"" + PATH + "\">", + " <components bundle=\"my-jersey-bundle:1.0\">", + " <package>com.yahoo.foo</package>", + " </components>", + " </rest-api>", + " <handler id=\"" + INJECTED_COMPONENT_ID + "\" />", + "</container>"); private RestApi restApi; - private JerseyHandler handler; + private Jersey2Servlet servlet; private RestApiContext context; + @Before public void setup() throws Exception { - createModel(root, DomBuilderTest.parse(restApiXml)); + createModel(root, restApiXml); root.validate(); - getContainerCluster(ClusterId).prepare(); - restApi = getContainerCluster(ClusterId).getRestApiMap().values().iterator().next(); - handler = (JerseyHandler) getContainerComponentNested(ClusterId, HandlerId); + getContainerCluster(CLUSTER_ID).prepare(); + restApi = getContainerCluster(CLUSTER_ID).getRestApiMap().values().iterator().next(); + servlet = restApi.getJersey2Servlet(); context = restApi.getContext(); } @Test - public void jersey_handler_has_correct_bindings() throws Exception { - setup(); - assertThat(handler, not(nullValue())); - assertThat(handler.getServerBindings(), hasItems(HttpBinding, HttpsBinding)); + public void jersey2_servlet_has_correct_binding_path() { + assertThat(servlet, not(nullValue())); + assertThat(servlet.bindingPath, is(PATH + "/*")); } @Test - public void jersey_bindings_are_included_in_config() throws Exception { - setup(); - JdiscBindingsConfig config = root.getConfig(JdiscBindingsConfig.class, ClusterId); - assertThat(config.handlers(HandlerId).serverBindings(), hasItems(HttpBinding, HttpsBinding)); + public void jersey2_servlet_has_correct_bundle_spec() { + assertThat(servlet.model.bundleInstantiationSpec.bundle.stringValue(), is(Jersey2Servlet.BUNDLE)); } @Test - public void jersey_handler_has_correct_bundle_spec() throws Exception { - setup(); - assertThat(handler.model.bundleInstantiationSpec.bundle.stringValue(), is(JerseyHandler.BUNDLE)); + public void rest_api_path_is_included_in_servlet_config() { + ServletPathsConfig config = root.getConfig(ServletPathsConfig.class, servlet.getConfigId()); + assertThat(config.servlets(servlet.getComponentId().stringValue()).path(), is(PATH + "/*")); } @Test - public void config_has_correct_jersey_mapping() throws Exception { - setup(); - JerseyInitConfig config = root.getConfig(JerseyInitConfig.class, handler.getConfigId()); - assertThat(config.jerseyMapping(), is(Path)); - } - - @Test - public void resource_bundles_are_included_in_config() throws Exception { - setup(); + public void resource_bundles_are_included_in_config() { JerseyBundlesConfig config = root.getConfig(JerseyBundlesConfig.class, context.getConfigId()); assertThat(config.bundles().size(), is(1)); assertThat(config.bundles(0).spec(), is("my-jersey-bundle:1.0")); } @Test - public void packages_to_scan_are_included_in_config() throws Exception { - setup(); + public void packages_to_scan_are_included_in_config() { JerseyBundlesConfig config = root.getConfig(JerseyBundlesConfig.class, context.getConfigId()); assertThat(config.bundles(0).packages(), contains("com.yahoo.foo")); } @Test - public void jersey_handler_is_included_in_components_config() throws Exception { - setup(); - ComponentsConfig config = root.getConfig(ComponentsConfig.class, ClusterId); - assertThat(config.toString(), containsString(".id \"" + HandlerId + "\"")); + public void jersey2_servlet_is_included_in_components_config() { + ComponentsConfig config = root.getConfig(ComponentsConfig.class, CLUSTER_ID); + assertThat(config.toString(), containsString(".id \"" + servlet.getComponentId().stringValue() + "\"")); } @Test - public void restApiContext_is_included_in_components_config() throws Exception { - setup(); - ComponentsConfig config = root.getConfig(ComponentsConfig.class, ClusterId); - assertThat(config.toString(), containsString(".id \"" + RestApiContextId + "\"")); + public void restApiContext_is_included_in_components_config() { + ComponentsConfig config = root.getConfig(ComponentsConfig.class, CLUSTER_ID); + assertThat(config.toString(), containsString(".id \"" + REST_API_CONTEXT_ID + "\"")); } @Test public void all_non_restApi_components_are_injected_to_RestApiContext() throws Exception { - setup(); - ComponentsConfig componentsConfig = root.getConfig(ComponentsConfig.class, ClusterId); + ComponentsConfig componentsConfig = root.getConfig(ComponentsConfig.class, CLUSTER_ID); - Set<ComponentId> clusterChildrenComponentIds = getContainerCluster(ClusterId).getAllComponents().stream() + Set<ComponentId> clusterChildrenComponentIds = getContainerCluster(CLUSTER_ID).getAllComponents().stream() .map(Component::getComponentId) .collect(Collectors.toSet()); @@ -136,7 +115,7 @@ public class RestApiTest extends ContainerModelBuilderTestBase { .map(child -> ((Component<?, ?>) child).getComponentId()) .collect(Collectors.toSet()); - //TODO: Review: replace with filtering against RestApiContext.isCycleGeneratingComponent + //TODO: try replacing with filtering against RestApiContext.isCycleGeneratingComponent ComponentId cycleInducingComponents = ComponentId.fromString("com.yahoo.container.handler.observability.ApplicationStatusHandler"); Set<ComponentId> expectedInjectedConfigIds = new HashSet<>(clusterChildrenComponentIds); @@ -165,49 +144,4 @@ public class RestApiTest extends ContainerModelBuilderTestBase { .get(); } - @Ignore // TODO: use for naming components instead - @Test - public void jdisc_components_can_be_injected() throws Exception { - setup(); - JerseyInjectionConfig config = root.getConfig(JerseyInjectionConfig.class, context.getConfigId()); - assertThat(config.inject(0).instance(), is("injectedHandler")); - assertThat(config.inject(0).forClass(), is("com.yahoo.handler.Handler")); - } - - @Ignore // TODO: use for naming a non-existent component instead - @Test(expected = IllegalArgumentException.class) - public void injecting_non_existent_component() throws Exception { - String restApiXml = - "<container version=\"1.0\" id=\"" + ClusterId + "\">\n" + - " <rest-api path=\"" + Path + "\">\n" + - " <components bundle=\"my-jersey-bundle:1.0\" />\n" + - " <inject jdisc-component=\"non-existent\" for-class=\"foo\" />\n" + - " </rest-api>\n" + - "</container>"; - createModel(root, DomBuilderTest.parse(restApiXml)); - root.validate(); - } - - @Test - public void legacy_syntax_should_produce_valid_model() throws Exception { - String legacyXml = - "<container version=\"1.0\" >\n" + - " <handler id=\"" + JerseyHandler.CLASS + "\" >\n" + - " <binding>" + HttpBinding + "</binding>\n" + - " <config name=\"jdisc.jersey.jersey-handler\">\n" + - " <jerseyMapping>jersey</jerseyMapping>\n" + - " </config>\n" + - " </handler>\n" + - "</container>"; - - createModel(root, DomBuilderTest.parse(legacyXml)); - - Handler<?> handler = (Handler<?>) getContainerComponent("container", JerseyHandler.CLASS); - assertThat(handler, not(nullValue())); - assertThat(handler.getServerBindings(), hasItem(HttpBinding)); - - JdiscBindingsConfig bindingsConfig = root.getConfig(JdiscBindingsConfig.class, ClusterId); - assertThat(bindingsConfig.handlers(JerseyHandler.CLASS).serverBindings(), hasItem(HttpBinding)); - } - } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTestBase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTestBase.java index e46e736dcd6..6a5611a7279 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTestBase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTestBase.java @@ -78,12 +78,4 @@ public abstract class ContainerModelBuilderTestBase { ComponentId.fromString(componentId)); } - // TODO: will not work with multiple instances of the same class - public Component<?, ?> getContainerComponentNested(String clusterId, String componentId) { - ComponentId id = ComponentId.fromString(componentId); - for (Component<?,?> component : getContainerCluster(clusterId).getAllComponents()) - if (id.equals(component.getComponentId())) - return component; - return null; - } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/JettyContainerModelBuilderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/JettyContainerModelBuilderTest.java index 54c4aabf44c..ceff9f3d4bb 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/JettyContainerModelBuilderTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/JettyContainerModelBuilderTest.java @@ -3,34 +3,30 @@ package com.yahoo.vespa.model.container.xml; import com.yahoo.config.model.builder.xml.test.DomBuilderTest; import com.yahoo.container.ComponentsConfig; -import com.yahoo.container.bundle.BundleInstantiationSpecification; import com.yahoo.container.jdisc.FilterBindingsProvider; import com.yahoo.jdisc.http.ConnectorConfig; -import com.yahoo.jdisc.http.ssl.DefaultSslKeyStoreConfigurator; -import com.yahoo.jdisc.http.ssl.DefaultSslTrustStoreConfigurator; import com.yahoo.vespa.model.container.ContainerCluster; import com.yahoo.vespa.model.container.component.SimpleComponent; import com.yahoo.vespa.model.container.http.ConnectorFactory; import com.yahoo.vespa.model.container.http.JettyHttpServer; +import com.yahoo.vespa.model.container.http.ssl.DefaultSslProvider; import org.junit.Test; import org.w3c.dom.Element; -import org.xml.sax.SAXException; -import java.io.IOException; -import java.util.Arrays; import java.util.List; -import java.util.Set; +import java.util.Optional; import static com.yahoo.jdisc.http.ConnectorConfig.Ssl.KeyStoreType; import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.CoreMatchers.hasItem; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.nullValue; import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; /** * @author einarmr + * @author mortent */ public class JettyContainerModelBuilderTest extends ContainerModelBuilderTestBase { @@ -192,60 +188,113 @@ public class JettyContainerModelBuilderTest extends ContainerModelBuilderTestBas } @Test - public void ssl_keystore_and_truststore_configurator_can_be_overriden() throws IOException, SAXException { + public void verify_that_ssl_element_generates_connector_config_and_inject_provider_component() { Element clusterElem = DomBuilderTest.parse( "<jdisc id='default' version='1.0' jetty='true'>", - " <http>", - " <server port='9000' id='foo'>", - " <ssl-keystore-configurator class='com.yahoo.MySslKeyStoreConfigurator' bundle='mybundle'/>", - " <ssl-truststore-configurator class='com.yahoo.MySslTrustStoreConfigurator' bundle='mybundle'/>", - " </server>", - " <server port='9001' id='bar'/>", - " </http>", + " <http>", + " <server port='9000' id='minimal'>", + " <ssl>", + " <private-key-file>/foo/key</private-key-file>", + " <certificate-file>/foo/cert</certificate-file>", + " </ssl>", + " </server>", + " <server port='9001' id='with-cacerts'>", + " <ssl>", + " <private-key-file>/foo/key</private-key-file>", + " <certificate-file>/foo/cert</certificate-file>", + " <ca-certificates-file>/foo/cacerts</ca-certificates-file>", + " </ssl>", + " </server>", + " <server port='9002' id='need-client-auth'>", + " <ssl>", + " <private-key-file>/foo/key</private-key-file>", + " <certificate-file>/foo/cert</certificate-file>", + " <client-authentication>need</client-authentication>", + " </ssl>", + " </server>", + " </http>", nodesXml, + "", "</jdisc>"); + createModel(root, clusterElem); + ConnectorConfig minimalCfg = root.getConfig(ConnectorConfig.class, "default/http/jdisc-jetty/minimal/default-ssl-provider@minimal"); + assertTrue(minimalCfg.ssl().enabled()); + assertThat(minimalCfg.ssl().privateKeyFile(), is(equalTo("/foo/key"))); + assertThat(minimalCfg.ssl().certificateFile(), is(equalTo("/foo/cert"))); + assertThat(minimalCfg.ssl().caCertificateFile(), is(equalTo(""))); + assertThat(minimalCfg.ssl().clientAuth(), is(equalTo(ConnectorConfig.Ssl.ClientAuth.Enum.DISABLED))); + + ConnectorConfig withCaCerts = root.getConfig(ConnectorConfig.class, "default/http/jdisc-jetty/with-cacerts/default-ssl-provider@with-cacerts"); + assertTrue(withCaCerts.ssl().enabled()); + assertThat(withCaCerts.ssl().privateKeyFile(), is(equalTo("/foo/key"))); + assertThat(withCaCerts.ssl().certificateFile(), is(equalTo("/foo/cert"))); + assertThat(withCaCerts.ssl().caCertificateFile(), is(equalTo("/foo/cacerts"))); + assertThat(withCaCerts.ssl().clientAuth(), is(equalTo(ConnectorConfig.Ssl.ClientAuth.Enum.DISABLED))); + + ConnectorConfig needClientAuth = root.getConfig(ConnectorConfig.class, "default/http/jdisc-jetty/need-client-auth/default-ssl-provider@need-client-auth"); + assertTrue(needClientAuth.ssl().enabled()); + assertThat(needClientAuth.ssl().privateKeyFile(), is(equalTo("/foo/key"))); + assertThat(needClientAuth.ssl().certificateFile(), is(equalTo("/foo/cert"))); + assertThat(needClientAuth.ssl().caCertificateFile(), is(equalTo(""))); + assertThat(needClientAuth.ssl().clientAuth(), is(equalTo(ConnectorConfig.Ssl.ClientAuth.Enum.NEED_AUTH))); + ContainerCluster cluster = (ContainerCluster) root.getChildren().get("default"); List<ConnectorFactory> connectorFactories = cluster.getChildrenByTypeRecursive(ConnectorFactory.class); - { - ConnectorFactory firstConnector = connectorFactories.get(0); - assertConnectorHasInjectedComponents(firstConnector, "ssl-keystore-configurator@foo", "ssl-truststore-configurator@foo"); - assertComponentHasClassNameAndBundle(getChildComponent(firstConnector, 0), - "com.yahoo.MySslKeyStoreConfigurator", - "mybundle"); - assertComponentHasClassNameAndBundle(getChildComponent(firstConnector, 1), - "com.yahoo.MySslTrustStoreConfigurator", - "mybundle"); - } - { - ConnectorFactory secondConnector = connectorFactories.get(1); - assertConnectorHasInjectedComponents(secondConnector, "ssl-keystore-configurator@bar", "ssl-truststore-configurator@bar"); - assertComponentHasClassNameAndBundle(getChildComponent(secondConnector, 0), - DefaultSslKeyStoreConfigurator.class.getName(), - "jdisc_http_service"); - assertComponentHasClassNameAndBundle(getChildComponent(secondConnector, 1), - DefaultSslTrustStoreConfigurator.class.getName(), - "jdisc_http_service"); - } + connectorFactories.forEach(connectorFactory -> assertChildComponentExists(connectorFactory, DefaultSslProvider.COMPONENT_CLASS)); } - private static void assertConnectorHasInjectedComponents(ConnectorFactory connectorFactory, String... componentNames) { - Set<String> injectedComponentIds = connectorFactory.getInjectedComponentIds(); - assertThat(injectedComponentIds.size(), equalTo(componentNames.length)); - Arrays.stream(componentNames) - .forEach(name -> assertThat(injectedComponentIds, hasItem(name))); + @Test + public void verify_tht_ssl_provider_configuration_configures_correct_config() { + Element clusterElem = DomBuilderTest.parse( + "<jdisc id='default' version='1.0' jetty='true'>", + " <http>", + " <server port='9000' id='ssl'>", + " <ssl-provider class='com.yahoo.CustomSslProvider' bundle='mybundle'/>", + " </server>", + " </http>", + nodesXml, + "", + "</jdisc>"); + + createModel(root, clusterElem); + ConnectorConfig sslProvider = root.getConfig(ConnectorConfig.class, "default/http/jdisc-jetty/ssl/ssl-provider@ssl"); + + assertTrue(sslProvider.ssl().enabled()); + + ContainerCluster cluster = (ContainerCluster) root.getChildren().get("default"); + List<ConnectorFactory> connectorFactories = cluster.getChildrenByTypeRecursive(ConnectorFactory.class); + ConnectorFactory connectorFactory = connectorFactories.get(0); + assertChildComponentExists(connectorFactory, "com.yahoo.CustomSslProvider"); } - private static SimpleComponent getChildComponent(ConnectorFactory connectorFactory, int index) { - return connectorFactory.getChildrenByTypeRecursive(SimpleComponent.class).get(index); + @Test + public void verify_that_container_factory_sees_same_config(){ + Element clusterElem = DomBuilderTest.parse( + "<jdisc id='default' version='1.0' jetty='true'>", + " <http>", + " <server port='9000' id='ssl'>", + " <ssl>", + " <private-key-file>/foo/key</private-key-file>", + " <certificate-file>/foo/cert</certificate-file>", + " </ssl>", + " </server>", + " </http>", + nodesXml, + "", + "</jdisc>"); + + createModel(root, clusterElem); + ConnectorConfig sslProvider = root.getConfig(ConnectorConfig.class, "default/http/jdisc-jetty/ssl"); + assertTrue(sslProvider.ssl().enabled()); } - private static void assertComponentHasClassNameAndBundle(SimpleComponent simpleComponent, - String className, - String bundleName) { - BundleInstantiationSpecification spec = simpleComponent.model.bundleInstantiationSpec; - assertThat(spec.classId.toString(), is(className)); - assertThat(spec.bundle.toString(), is(bundleName)); + private static void assertChildComponentExists(ConnectorFactory connectorFactory, String className) { + Optional<SimpleComponent> simpleComponent = connectorFactory.getChildren().values().stream() + .map(z -> (SimpleComponent) z) + .filter(component -> component.getClassId().stringValue().equals(className)) + .findFirst(); + assertTrue(simpleComponent.isPresent()); } private void assertJettyServerInConfig() { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java new file mode 100644 index 00000000000..2ae629562d0 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java @@ -0,0 +1,71 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; + +import com.yahoo.config.model.ApplicationPackageTester; +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.vespa.model.VespaModel; +import org.xml.sax.SAXException; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Optional; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; + +/** + * Helper for testing of imported models. + * More duplicated functionality across tests on imported models should be moved here + * + * @author bratseth + */ +public class ImportedModelTester { + + private final String modelName; + private final Path applicationDir; + + public ImportedModelTester(String modelName, Path applicationDir) { + this.modelName = modelName; + this.applicationDir = applicationDir; + } + + public VespaModel createVespaModel() { + try { + return new VespaModel(ApplicationPackageTester.create(applicationDir.toString()).app()); + } + catch (SAXException | IOException e) { + throw new RuntimeException(e); + } + } + + /** + * Verifies that the constant with the given name exists, and - only if an expected size is given - + * that the content of the constant is available and has the expected size. + */ + public void assertLargeConstant(String constantName, VespaModel model, Optional<Long> expectedSize) { + try { + Path constantApplicationPackagePath = Path.fromString("models.generated/" + modelName + "/constants").append(constantName + ".tbf"); + RankingConstant rankingConstant = model.rankingConstants().get(constantName); + assertEquals(constantName, rankingConstant.getName()); + assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString())); + + if (expectedSize.isPresent()) { + Path constantPath = applicationDir.append(constantApplicationPackagePath); + assertTrue("Constant file '" + constantPath + "' has been written", + constantPath.toFile().exists()); + Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile()))); + assertEquals(expectedSize.get().longValue(), deserializedConstant.size()); + } + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java new file mode 100644 index 00000000000..b7b3fc99e20 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -0,0 +1,145 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; + +import ai.vespa.models.evaluation.Model; +import ai.vespa.models.evaluation.ModelsEvaluator; +import ai.vespa.models.evaluation.RankProfilesConfigImporter; +import com.yahoo.component.ComponentId; +import com.yahoo.config.FileReference; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.filedistribution.fileacquirer.FileAcquirer; +import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.container.ContainerCluster; +import org.junit.Test; + +import java.io.IOException; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** + * @author bratseth + */ +public class ModelEvaluationTest { + + @Test + public void testMl_serving() throws IOException { + Path appDir = Path.fromString("src/test/cfg/application/ml_serving"); + Path storedAppDir = appDir.append("copy"); + try { + ImportedModelTester tester = new ImportedModelTester("ml_serving", appDir); + assertHasMlModels(tester.createVespaModel()); + + // At this point the expression is stored - copy application to another location which do not have a models dir + storedAppDir.toFile().mkdirs(); + IOUtils.copy(appDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString()); + IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + ImportedModelTester storedTester = new ImportedModelTester("ml_serving", storedAppDir); + assertHasMlModels(storedTester.createVespaModel()); + } + finally { + IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + IOUtils.recursiveDeleteDir(storedAppDir.toFile()); + } + } + + /** Tests that we do not load models (which will waste memory) when not requested */ + @Test + public void testMl_serving_not_activated() throws IOException { + Path appDir = Path.fromString("src/test/cfg/application/ml_serving_not_activated"); + try { + ImportedModelTester tester = new ImportedModelTester("ml_serving", appDir); + VespaModel model = tester.createVespaModel(); + ContainerCluster cluster = model.getContainerClusters().get("container"); + assertNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName()))); + + RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); + cluster.getConfig(b); + RankProfilesConfig config = new RankProfilesConfig(b); + + assertEquals(0, config.rankprofile().size()); + } + finally { + IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + } + } + + private void assertHasMlModels(VespaModel model) { + ContainerCluster cluster = model.getContainerClusters().get("container"); + assertNotNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName()))); + + RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); + cluster.getConfig(b); + RankProfilesConfig config = new RankProfilesConfig(b); + + RankingConstantsConfig.Builder cb = new RankingConstantsConfig.Builder(); + cluster.getConfig(cb); + RankingConstantsConfig constantsConfig = new RankingConstantsConfig(cb); + + assertEquals(4, config.rankprofile().size()); + Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet()); + assertTrue(modelNames.contains("xgboost_2_2")); + assertTrue(modelNames.contains("mnist_saved")); + assertTrue(modelNames.contains("mnist_softmax")); + assertTrue(modelNames.contains("mnist_softmax_saved")); + + ModelsEvaluator evaluator = new ModelsEvaluator(new ToleratingMissingConstantFilesRankProfilesConfigImporter(MockFileAcquirer.returnFile(null)) + .importFrom(config, constantsConfig)); + + assertEquals(4, evaluator.models().size()); + + Model xgboost = evaluator.models().get("xgboost_2_2"); + assertNotNull(xgboost); + assertNotNull(xgboost.evaluatorOf()); + assertNotNull(xgboost.evaluatorOf("xgboost_2_2")); + + Model tensorflow_mnist = evaluator.models().get("mnist_saved"); + assertNotNull(tensorflow_mnist); + assertNotNull(tensorflow_mnist.evaluatorOf("serving_default")); + assertNotNull(tensorflow_mnist.evaluatorOf("serving_default", "y")); + assertNotNull(tensorflow_mnist.evaluatorOf("serving_default.y")); + assertNotNull(evaluator.evaluatorOf("mnist_saved", "serving_default.y")); + assertNotNull(evaluator.evaluatorOf("mnist_saved", "serving_default", "y")); + + Model onnx_mnist_softmax = evaluator.models().get("mnist_softmax"); + assertNotNull(onnx_mnist_softmax); + assertNotNull(onnx_mnist_softmax.evaluatorOf()); + assertNotNull(onnx_mnist_softmax.evaluatorOf("default")); + assertNotNull(onnx_mnist_softmax.evaluatorOf("default", "add")); + assertNotNull(onnx_mnist_softmax.evaluatorOf("default.add")); + assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default.add")); + assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default", "add")); + + Model tensorflow_mnist_softmax = evaluator.models().get("mnist_softmax_saved"); + assertNotNull(tensorflow_mnist_softmax); + assertNotNull(tensorflow_mnist_softmax.evaluatorOf()); + assertNotNull(tensorflow_mnist_softmax.evaluatorOf("serving_default")); + assertNotNull(tensorflow_mnist_softmax.evaluatorOf("serving_default", "y")); + } + + // We don't have function file distribution so just return empty tensor constants + private static class ToleratingMissingConstantFilesRankProfilesConfigImporter extends RankProfilesConfigImporter { + + public ToleratingMissingConstantFilesRankProfilesConfigImporter(FileAcquirer fileAcquirer) { + super(fileAcquirer); + } + + protected Tensor readTensorFromFile(String name, TensorType type, FileReference fileReference) { + return Tensor.from(type, "{}"); + } + + } + +} diff --git a/config-model/src/test/schema-test-files/services.xml b/config-model/src/test/schema-test-files/services.xml index e740e7d86b0..632abe68ab7 100644 --- a/config-model/src/test/schema-test-files/services.xml +++ b/config-model/src/test/schema-test-files/services.xml @@ -112,15 +112,23 @@ </request-chain> </filtering> - <server port="4080" id="myServer"> - <ssl-keystore-configurator class="com.yahoo.MySslKeyStoreConfigurator" bundle="mybundle" /> - <ssl-truststore-configurator class="com.yahoo.MySslTrustStoreConfigurator" bundle="mybundle" /> - </server> + <server port="4080" id="myServer"/> <server port="4081" id="anotherServer"> <config name="container.jdisc.config.http-server"> <maxChunkSize>9999</maxChunkSize> </config> </server> + <server port="4082" id="defaultSsl"> + <ssl> + <private-key-file>/foo/key</private-key-file> + <certificate-file>/foo/cert</certificate-file> + <ca-certificates-file>/foo/cacerts</ca-certificates-file> + <client-authentication>want</client-authentication> + </ssl> + </server> + <server port="4083" id="sslProvider"> + <ssl-provider class="com.yahoo.MySslProvider" bundle="mybundle"/> + </server> </http> <accesslog type='json' diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java b/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java index 5204da08307..6df617ea335 100644 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java @@ -65,12 +65,6 @@ public final class Capacity { return fromNodeCount(capacity, Optional.empty(), false, true); } - // TODO: Remove after July 2018 - @Deprecated - public static Capacity fromNodeCount(int nodeCount, Optional<String> flavor, boolean required) { - return new Capacity(nodeCount, flavor, required, true, NodeType.tenant); - } - public static Capacity fromNodeCount(int nodeCount, Optional<String> flavor, boolean required, boolean canFail) { return new Capacity(nodeCount, flavor, required, canFail, NodeType.tenant); } diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/InstanceName.java b/config-provisioning/src/main/java/com/yahoo/config/provision/InstanceName.java index 0f1b298ba83..703528e5d33 100644 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/InstanceName.java +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/InstanceName.java @@ -46,6 +46,10 @@ public class InstanceName implements Comparable<InstanceName> { return equals(InstanceName.defaultName()); } + public boolean isTester() { + return value().endsWith("-t"); + } + public String value() { return instanceName; } @Override diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/ProvisionInfo.java b/config-provisioning/src/main/java/com/yahoo/config/provision/ProvisionInfo.java deleted file mode 100644 index ca8d531634b..00000000000 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/ProvisionInfo.java +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.config.provision; - -import com.yahoo.slime.ArrayTraverser; -import com.yahoo.slime.Inspector; -import com.yahoo.vespa.config.SlimeUtils; - -import java.util.LinkedHashSet; -import java.util.Optional; -import java.util.Set; - -/** - * @author bratseth - * @deprecated use AllocatedHosts - */ -// TODO: Remove when no version older than 6.143 is in production anywhere -@Deprecated -@SuppressWarnings("unused") -public class ProvisionInfo extends AllocatedHosts { - - private static final String mappingKey = "mapping"; - private static final String hostSpecKey = "hostSpec"; - - private ProvisionInfo(Set<HostSpec> hosts) { - super(hosts); - } - - public static ProvisionInfo withHosts(Set<HostSpec> hosts) { - return new ProvisionInfo(hosts); - } - - public static ProvisionInfo fromJson(byte[] json, Optional<NodeFlavors> nodeFlavors) { - return fromSlime(SlimeUtils.jsonToSlime(json).get(), nodeFlavors); - } - - private static ProvisionInfo fromSlime(Inspector inspector, Optional<NodeFlavors> nodeFlavors) { - Inspector array = inspector.field(mappingKey); - Set<HostSpec> hosts = new LinkedHashSet<>(); - array.traverse(new ArrayTraverser() { - @Override - public void entry(int i, Inspector inspector) { - hosts.add(hostFromSlime(inspector.field(hostSpecKey), nodeFlavors)); - } - }); - return new ProvisionInfo(hosts); - } - -} diff --git a/config-proxy/pom.xml b/config-proxy/pom.xml index a266f68efe2..e5498aae5ec 100644 --- a/config-proxy/pom.xml +++ b/config-proxy/pom.xml @@ -63,6 +63,11 @@ <artifactId>filedistribution</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>org.bouncycastle</groupId> + <artifactId>bcpkix-jdk15on</artifactId> + <scope>compile</scope> + </dependency> </dependencies> <build> <plugins> @@ -82,24 +87,31 @@ </configuration> </plugin> <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-assembly-plugin</artifactId> - <configuration> - <descriptorRefs> - <descriptorRef>jar-with-dependencies</descriptorRef> - </descriptorRefs> - </configuration> - <executions> - <execution> - <id>make-assembly</id> - <phase>package</phase> - <goals> - <goal>single</goal> - </goals> - </execution> - </executions> - </plugin> - <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-shade-plugin</artifactId> + <configuration> + <finalName>${project.artifactId}-jar-with-dependencies</finalName> + <filters> + <filter> + <!-- Don't include signature files from bouncycastle in uber jar. --> + <artifact>*:*</artifact> + <excludes> + <exclude>META-INF/*.SF</exclude> + <exclude>META-INF/*.DSA</exclude> + <exclude>META-INF/*.RSA</exclude> + </excludes> + </filter> + </filters> + </configuration> + <executions> + <execution> + <phase>package</phase> + <goals> + <goal>shade</goal> + </goals> + </execution> + </executions> + </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-surefire-plugin</artifactId> <configuration> diff --git a/config-proxy/src/main/sh/vespa-config-ctl.sh b/config-proxy/src/main/sh/vespa-config-ctl.sh index a670e69cdbf..649eef951c0 100755 --- a/config-proxy/src/main/sh/vespa-config-ctl.sh +++ b/config-proxy/src/main/sh/vespa-config-ctl.sh @@ -103,6 +103,7 @@ export LD_LIBRARY_PATH="$VESPA_HOME/lib64" case $1 in start) + nohup sbin/vespa-retention-enforcer > ${LOGDIR}/vre-start.log 2>&1 </dev/null & configsources=`bin/vespa-print-default configservers_rpc` userargs=$vespa_base__jvmargs_configproxy if [ "$userargs" == "" ]; then diff --git a/config/src/main/java/com/yahoo/config/subscription/CfgConfigPayloadBuilder.java b/config/src/main/java/com/yahoo/config/subscription/CfgConfigPayloadBuilder.java index ad99afe3f36..90532344a58 100644 --- a/config/src/main/java/com/yahoo/config/subscription/CfgConfigPayloadBuilder.java +++ b/config/src/main/java/com/yahoo/config/subscription/CfgConfigPayloadBuilder.java @@ -155,7 +155,7 @@ public class CfgConfigPayloadBuilder { } private boolean isArray(String name) { - return name.endsWith("]"); + return name.endsWith("]") && !name.startsWith("["); } private boolean isMap(String name) { diff --git a/config/src/tests/failover/failover.cpp b/config/src/tests/failover/failover.cpp index 0f4a7e6bf6f..990ca761e7e 100644 --- a/config/src/tests/failover/failover.cpp +++ b/config/src/tests/failover/failover.cpp @@ -38,7 +38,7 @@ struct RPCServer : public FRT_Invokable { void init(FRT_Supervisor * s) { FRT_ReflectionBuilder rb(s); - rb.DefineMethod("config.v3.getConfig", requestTypes.c_str(), responseTypes.c_str(), true, + rb.DefineMethod("config.v3.getConfig", requestTypes.c_str(), responseTypes.c_str(), FRT_METHOD(RPCServer::getConfig), this); } diff --git a/config/src/tests/file_acquirer/file_acquirer_test.cpp b/config/src/tests/file_acquirer/file_acquirer_test.cpp index 0d2e2bf9144..0453c6ddbd0 100644 --- a/config/src/tests/file_acquirer/file_acquirer_test.cpp +++ b/config/src/tests/file_acquirer/file_acquirer_test.cpp @@ -11,7 +11,7 @@ struct ServerFixture : FRT_Invokable { vespalib::string spec; void init_rpc() { FRT_ReflectionBuilder rb(&orb); - rb.DefineMethod("waitFor", "s", "s", true, FRT_METHOD(ServerFixture::RPC_waitFor), this); + rb.DefineMethod("waitFor", "s", "s", FRT_METHOD(ServerFixture::RPC_waitFor), this); rb.MethodDesc("wait for and resolve file reference"); rb.ParamDesc("file_ref", "file reference to wait for and resolve"); rb.ReturnDesc("file_path", "actual path to the requested file"); diff --git a/configdefinitions/src/vespa/configserver.def b/configdefinitions/src/vespa/configserver.def index 70df8ce5164..c90709bf4dd 100644 --- a/configdefinitions/src/vespa/configserver.def +++ b/configdefinitions/src/vespa/configserver.def @@ -52,7 +52,7 @@ athenzDnsSuffix string default="" ztsUrl string default="" # Node admin -nodeAdminInContainer bool default=true +nodeAdminInContainer bool default=false # Maintainers maintainerIntervalMinutes int default=60 diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java index a8b4844ca43..6a55fb77933 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java @@ -10,6 +10,8 @@ import com.yahoo.config.FileReference; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationMetaData; import com.yahoo.config.application.api.DeployLogger; +import com.yahoo.config.model.api.HostInfo; +import com.yahoo.config.model.api.ServiceInfo; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.Environment; import com.yahoo.config.provision.HostFilter; @@ -36,6 +38,7 @@ import com.yahoo.vespa.config.server.configchange.RestartActions; import com.yahoo.vespa.config.server.deploy.DeployHandlerLogger; import com.yahoo.vespa.config.server.deploy.Deployment; import com.yahoo.vespa.config.server.http.CompressedApplicationInputStream; +import com.yahoo.vespa.config.server.http.LogRetriever; import com.yahoo.vespa.config.server.http.SimpleHttpFetcher; import com.yahoo.vespa.config.server.http.v2.PrepareResult; import com.yahoo.vespa.config.server.provision.HostProvisionerProvider; @@ -61,6 +64,7 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.Arrays; +import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Optional; @@ -477,6 +481,14 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye return convergeChecker.servicesToCheck(getApplication(applicationId), uri, timeout); } + // ---------------- Logs ---------------------------------------------------------------- + + public HttpResponse getLogs(ApplicationId applicationId) { + String logServerHostName = getLogServerURI(applicationId); + LogRetriever logRetriever = new LogRetriever(); + return logRetriever.getLogs(logServerHostName); + } + // ---------------- Session operations ---------------------------------------------------------------- /** @@ -690,14 +702,39 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye } } + private String getLogServerURI(ApplicationId applicationId) { + Application application = getApplication(applicationId); + Collection<HostInfo> hostInfos = application.getModel().getHosts(); + + HostInfo logServerHostInfo = hostInfos.stream() + .filter(host -> host.getServices().stream() + .filter(serviceInfo -> + serviceInfo.getServiceType().equalsIgnoreCase("logserver")) + .count() > 0) + .findFirst().orElseThrow(() -> new IllegalArgumentException("Could not find HostInfo for LogServer")); + + ServiceInfo containerServiceInfo = logServerHostInfo.getServices().stream() + .filter(service -> service.getServiceType().equals("container")) + .findFirst().orElseThrow(() -> new IllegalArgumentException("No container running on logserver host")); + + int port = containerServiceInfo.getPorts().stream() + .filter(portInfo -> portInfo.getTags().stream() + .filter(tag -> tag.equalsIgnoreCase("http")).count() > 0) + .findFirst().orElseThrow(() -> new IllegalArgumentException("Could not find HTTP port")) + .getPort(); + + return "http://" + logServerHostInfo.getHostname() + ":" + port + "/logs"; + } + /** Returns version to use when deploying application in given environment */ - static Version decideVersion(ApplicationId application, Environment environment, Version targetVersion, boolean bootstrap) { - if (environment.isManuallyDeployed() && - !"hosted-vespa".equals(application.tenant().value()) && // Never change version of system applications - !bootstrap) { // Do not use current version when bootstrapping config server + static Version decideVersion(ApplicationId application, Environment environment, Version sessionVersion, boolean bootstrap) { + if ( environment.isManuallyDeployed() + && ! "hosted-vespa".equals(application.tenant().value()) // Never change version of system applications + && ! application.instance().isTester() // Never upgrade tester containers + && ! bootstrap) { // Do not use current version when bootstrapping config server return Vtag.currentVersion; } - return targetVersion; + return sessionVersion; } public Slime createDeployLog() { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java index 91cb5891e84..0f507624188 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java @@ -23,7 +23,9 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -66,8 +68,9 @@ public class ConfigConvergenceChecker extends AbstractComponent { .filter(service -> serviceTypesToCheck.contains(service.getServiceType())) .forEach(service -> getStatePort(service).ifPresent(port -> servicesToCheck.add(service)))); - long currentGeneration = getServiceGeneration(servicesToCheck, timeoutPerService); - return new ServiceListResponse(200, servicesToCheck, requestUrl, application.getApplicationGeneration(), + Map<ServiceInfo, Long> currentGenerations = getServiceGenerations(servicesToCheck, timeoutPerService); + long currentGeneration = currentGenerations.values().stream().mapToLong(Long::longValue).min().orElse(-1); + return new ServiceListResponse(200, currentGenerations, requestUrl, application.getApplicationGeneration(), currentGeneration); } @@ -100,23 +103,21 @@ public class ConfigConvergenceChecker extends AbstractComponent { } /** Get service generation for a list of services. Returns the minimum generation of all services */ - private long getServiceGeneration(List<ServiceInfo> services, Duration timeout) { - List<URI> serviceUris = services.stream() - .map(s -> "http://" + s.getHostName() + ":" + getStatePort(s).get()) - .map(URI::create) - .collect(Collectors.toList()); - long generation = -1; - for (URI uri : serviceUris) { - try { - long serviceGeneration = getServiceGeneration(uri, timeout); - if (generation == -1 || serviceGeneration < generation) { - generation = serviceGeneration; - } - } catch (ProcessingException e) { // Cannot connect to service to determine service generation - return -1; - } - } - return generation; + private Map<ServiceInfo, Long> getServiceGenerations(List<ServiceInfo> services, Duration timeout) { + return services.stream() + .collect(Collectors.toMap(service -> service, + service -> { + try { + return getServiceGeneration(URI.create("http://" + service.getHostName() + + ":" + getStatePort(service).get()), timeout); + } + catch (ProcessingException e) { // Cannot connect to service to determine service generation + return -1L; + } + }, + (v1, v2) -> { throw new IllegalStateException("Duplicate keys for values '" + v1 + "' and '" + v2 + "'."); }, + LinkedHashMap::new + )); } /** Get service generation of service at given URL */ @@ -160,7 +161,7 @@ public class ConfigConvergenceChecker extends AbstractComponent { } private static long generationFromContainerState(JsonNode state) { - return state.get("config").get("generation").asLong(); + return state.get("config").get("generation").asLong(-1); } private static StateApi createStateApi(Client client, URI uri) { @@ -171,19 +172,20 @@ public class ConfigConvergenceChecker extends AbstractComponent { private static class ServiceListResponse extends JSONResponse { // Pre-condition: servicesToCheck has a state port - private ServiceListResponse(int status, List<ServiceInfo> servicesToCheck, URI uri, long wantedGeneration, + private ServiceListResponse(int status, Map<ServiceInfo, Long> servicesToCheck, URI uri, long wantedGeneration, long currentGeneration) { super(status); Cursor serviceArray = object.setArray("services"); - for (ServiceInfo s : servicesToCheck) { - Cursor service = serviceArray.addObject(); - String hostName = s.getHostName(); - int statePort = getStatePort(s).get(); - service.setString("host", hostName); - service.setLong("port", statePort); - service.setString("type", s.getServiceType()); - service.setString("url", uri.toString() + "/" + hostName + ":" + statePort); - } + servicesToCheck.forEach((service, generation) -> { + Cursor serviceObject = serviceArray.addObject(); + String hostName = service.getHostName(); + int statePort = getStatePort(service).get(); + serviceObject.setString("host", hostName); + serviceObject.setLong("port", statePort); + serviceObject.setString("type", service.getServiceType()); + serviceObject.setString("url", uri.toString() + "/" + hostName + ":" + statePort); + serviceObject.setLong("currentGeneration", generation); + }); object.setString("url", uri.toString()); object.setLong("currentGeneration", currentGeneration); object.setLong("wantedGeneration", wantedGeneration); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/LogRetriever.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/LogRetriever.java new file mode 100644 index 00000000000..dd60d158313 --- /dev/null +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/LogRetriever.java @@ -0,0 +1,41 @@ +package com.yahoo.vespa.config.server.http; + +import com.yahoo.container.jdisc.HttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClientBuilder; +import org.apache.http.util.EntityUtils; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.logging.Level; +import java.util.logging.Logger; + + +public class LogRetriever { + + private final static Logger log = Logger.getLogger(LogRetriever.class.getName()); + + public HttpResponse getLogs(String logServerHostname) { + HttpGet get = new HttpGet(logServerHostname); + try (CloseableHttpClient httpClient = HttpClientBuilder.create().build()) { + org.apache.http.HttpResponse response = httpClient.execute(get); + String responseBody = EntityUtils.toString(response.getEntity(), "UTF-8"); + return new HttpResponse(response.getStatusLine().getStatusCode()) { + @Override + public void render(OutputStream outputStream) throws IOException { + if (response.getEntity() != null ) outputStream.write(responseBody.getBytes()); + } + }; + } catch (IOException e) { + log.log(Level.WARNING, "Failed to retrieve logs from log server", e); + return new HttpResponse(404) { + @Override + public void render(OutputStream outputStream) throws IOException { + outputStream.write(e.toString().getBytes()); + } + }; + } + + } +} diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java index 2004ab95144..b65cb370f93 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java @@ -96,6 +96,10 @@ public class ApplicationHandler extends HttpHandler { return applicationRepository.filedistributionStatus(applicationId, timeout); } + if (isLogRequest(request)) { + return applicationRepository.getLogs(applicationId); + } + return new GetApplicationResponse(Response.Status.OK, applicationRepository.getApplicationGeneration(applicationId)); } @@ -140,7 +144,13 @@ public class ApplicationHandler extends HttpHandler { "http://*/application/v2/tenant/*/application/*/environment/*/region/*/instance/*/serviceconverge/*", "http://*/application/v2/tenant/*/application/*/environment/*/region/*/instance/*/clustercontroller/*/status/*", "http://*/application/v2/tenant/*/application/*/environment/*/region/*/instance/*", - "http://*/application/v2/tenant/*/application/*"); + "http://*/application/v2/tenant/*/application/*", + "http://*/application/v2/tenant/*/application/*/logs"); + } + + private static boolean isLogRequest(HttpRequest request) { + return getBindingMatch(request).groupCount() == 4 && + request.getUri().getPath().endsWith("/logs"); } private static boolean isServiceConvergeListRequest(HttpRequest request) { diff --git a/configserver/src/main/resources/configserver-app/services.xml b/configserver/src/main/resources/configserver-app/services.xml index 8a99869e69a..60dd7b0cea2 100644 --- a/configserver/src/main/resources/configserver-app/services.xml +++ b/configserver/src/main/resources/configserver-app/services.xml @@ -147,6 +147,8 @@ <binding>https://*/application/v2/tenant/*/application/*/environment/*/region/*/instance/*</binding> <binding>http://*/application/v2/tenant/*/application/*</binding> <binding>https://*/application/v2/tenant/*/application/*</binding> + <binding>http://*/application/v2/tenant/*/application/*/logs</binding> + <binding>https://*/application/v2/tenant/*/application/*/logs</binding> </handler> <handler id='com.yahoo.vespa.config.server.http.v2.HttpGetConfigHandler' bundle='configserver'> <binding>http://*/config/v2/tenant/*/application/*/*</binding> diff --git a/configserver/src/test/apps/app-logserver-with-container/hosts.xml b/configserver/src/test/apps/app-logserver-with-container/hosts.xml new file mode 100644 index 00000000000..d5a51f050fd --- /dev/null +++ b/configserver/src/test/apps/app-logserver-with-container/hosts.xml @@ -0,0 +1,8 @@ +<?xml version="1.0" encoding="utf-8" ?> +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<hosts> + <host name="localhost"> + <alias>node1</alias> + </host> +</hosts> + diff --git a/configserver/src/test/apps/app-logserver-with-container/services.xml b/configserver/src/test/apps/app-logserver-with-container/services.xml new file mode 100644 index 00000000000..3b88fc3879d --- /dev/null +++ b/configserver/src/test/apps/app-logserver-with-container/services.xml @@ -0,0 +1,18 @@ +<?xml version="1.0" encoding="utf-8" ?> +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<services version="1.0"> + + <admin version="2.0"> + <adminserver hostalias="node1"/> + <logserver hostalias="node1"/> + </admin> + + + + <container version="1.0"> + <nodes> + <node hostalias="node1" /> + </nodes> + </container> + +</services> diff --git a/configserver/src/test/apps/app/services.xml b/configserver/src/test/apps/app/services.xml index 6cc30b8b6ec..457a3fad397 100644 --- a/configserver/src/test/apps/app/services.xml +++ b/configserver/src/test/apps/app/services.xml @@ -4,6 +4,7 @@ <admin version="2.0"> <adminserver hostalias="node1"/> + <logserver hostalias="node1" /> </admin> <content version="1.0"> diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java index d9a653a1dc2..120119f35bb 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java @@ -1,6 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.config.server; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.client.WireMock; import com.google.common.io.Files; import com.yahoo.cloud.config.ConfigserverConfig; import com.yahoo.component.Version; @@ -13,6 +15,7 @@ import com.yahoo.config.provision.Environment; import com.yahoo.config.provision.InstanceName; import com.yahoo.config.provision.Provisioner; import com.yahoo.config.provision.TenantName; +import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.io.IOUtils; import com.yahoo.test.ManualClock; import com.yahoo.text.Utf8; @@ -41,6 +44,11 @@ import java.util.Collections; import java.util.Optional; import java.util.Set; +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -58,6 +66,7 @@ public class ApplicationRepositoryTest { private final static File testApp = new File("src/test/apps/app"); private final static File testAppJdiscOnly = new File("src/test/apps/app-jdisc-only"); private final static File testAppJdiscOnlyRestart = new File("src/test/apps/app-jdisc-only-restart"); + private final static File testAppLogServerWithContainer = new File("src/test/apps/app-logserver-with-container"); private final static TenantName tenant1 = TenantName.from("test1"); private final static TenantName tenant2 = TenantName.from("test2"); @@ -109,6 +118,27 @@ public class ApplicationRepositoryTest { } @Test + public void getLogs(){ + WireMockServer wireMock = new WireMockServer(wireMockConfig().port(8080)); + wireMock.start(); + WireMock.configureFor("localhost", wireMock.port()); + stubFor(get(urlEqualTo("/logs")) + .willReturn(aResponse() + .withStatus(200))); + wireMock.start(); + deployApp(testAppLogServerWithContainer); + HttpResponse response = applicationRepository.getLogs(applicationId()); + assertEquals(response.getStatus(),200); + wireMock.stop(); + } + + @Test(expected = IllegalArgumentException.class) + public void getLogsNoContainerOnLogServerHostShouldThrowException() { + deployApp(testApp); + applicationRepository.getLogs(applicationId()); + } + + @Test public void deleteUnusedTenants() { // Set clock to epoch plus hour, as mock curator will always return epoch as creation time Instant now = ManualClock.at("1970-01-01T01:00:00"); @@ -135,19 +165,25 @@ public class ApplicationRepositoryTest { public void decideVersion() { ApplicationId regularApp = ApplicationId.from("tenant1", "application1", "default"); ApplicationId systemApp = ApplicationId.from("hosted-vespa", "routing", "default"); - Version targetVersion = Version.fromString("5.0"); + ApplicationId testerApp = ApplicationId.from("tenant1", "application1", "default-t"); + Version sessionVersion = Version.fromString("5.0"); + + // Always use session version for system application + assertEquals(sessionVersion, ApplicationRepository.decideVersion(systemApp, Environment.prod, sessionVersion, false)); + assertEquals(sessionVersion, ApplicationRepository.decideVersion(systemApp, Environment.dev, sessionVersion, false)); + assertEquals(sessionVersion, ApplicationRepository.decideVersion(systemApp, Environment.perf, sessionVersion, false)); - // Always use target for system application - assertEquals(targetVersion, ApplicationRepository.decideVersion(systemApp, Environment.prod, targetVersion, false)); - assertEquals(targetVersion, ApplicationRepository.decideVersion(systemApp, Environment.dev, targetVersion, false)); - assertEquals(targetVersion, ApplicationRepository.decideVersion(systemApp, Environment.perf, targetVersion, false)); + // Always use session version for tester application + assertEquals(sessionVersion, ApplicationRepository.decideVersion(testerApp, Environment.prod, sessionVersion, false)); + assertEquals(sessionVersion, ApplicationRepository.decideVersion(testerApp, Environment.dev, sessionVersion, false)); + assertEquals(sessionVersion, ApplicationRepository.decideVersion(testerApp, Environment.perf, sessionVersion, false)); // Target for regular application depends on environment - assertEquals(targetVersion, ApplicationRepository.decideVersion(regularApp, Environment.prod, targetVersion, false)); - assertEquals(Vtag.currentVersion, ApplicationRepository.decideVersion(regularApp, Environment.dev, targetVersion, false)); + assertEquals(sessionVersion, ApplicationRepository.decideVersion(regularApp, Environment.prod, sessionVersion, false)); + assertEquals(Vtag.currentVersion, ApplicationRepository.decideVersion(regularApp, Environment.dev, sessionVersion, false)); // If bootstrap, version should be target version - assertEquals(targetVersion, ApplicationRepository.decideVersion(regularApp, Environment.dev, targetVersion, true)); - assertEquals(Vtag.currentVersion, ApplicationRepository.decideVersion(regularApp, Environment.perf, targetVersion, false)); + assertEquals(sessionVersion, ApplicationRepository.decideVersion(regularApp, Environment.dev, sessionVersion, true)); + assertEquals(Vtag.currentVersion, ApplicationRepository.decideVersion(regularApp, Environment.perf, sessionVersion, false)); } @Test diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/application/ConfigConvergenceCheckerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/application/ConfigConvergenceCheckerTest.java index 487e96f17b2..fdea0fced67 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/application/ConfigConvergenceCheckerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/application/ConfigConvergenceCheckerTest.java @@ -2,7 +2,6 @@ package com.yahoo.vespa.config.server.application; import com.github.tomakehurst.wiremock.junit.WireMockRule; -import com.github.tomakehurst.wiremock.stubbing.Scenario; import com.yahoo.config.model.api.Model; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ApplicationName; @@ -47,16 +46,20 @@ public class ConfigConvergenceCheckerTest { private Application application; private ConfigConvergenceChecker checker; private URI service; + private URI service2; @Rule public TemporaryFolder folder = new TemporaryFolder(); @Rule public final WireMockRule wireMock = new WireMockRule(options().dynamicPort(), true); + @Rule + public final WireMockRule wireMock2 = new WireMockRule(options().dynamicPort(), true); @Before public void setup() { service = testServer(); + service2 = testServer(wireMock2); Model mockModel = MockModel.createContainer(service.getHost(), service.getPort()); application = new Application(mockModel, new ServerCache(), @@ -115,7 +118,8 @@ public class ConfigConvergenceCheckerTest { " \"host\": \"" + serviceUrl.getHost() + "\",\n" + " \"port\": " + serviceUrl.getPort() + ",\n" + " \"type\": \"container\",\n" + - " \"url\": \"" + serviceUrl.toString() + "\"\n" + + " \"url\": \"" + serviceUrl.toString() + "\",\n" + + " \"currentGeneration\":" + 3 + "\n" + " }\n" + " ],\n" + " \"url\": \"" + requestUrl.toString() + "\",\n" + @@ -130,49 +134,45 @@ public class ConfigConvergenceCheckerTest { { // Model with two hosts on different generations MockModel model = new MockModel(Arrays.asList( - // Reuse hostname and port to avoid the need for two WireMock servers MockModel.createContainerHost(service.getHost(), service.getPort()), - MockModel.createContainerHost(service.getHost(), service.getPort())) + MockModel.createContainerHost(service2.getHost(), service2.getPort())) ); Application application = new Application(model, new ServerCache(), 4, false, Version.fromIntValues(0, 0, 0), MetricUpdater.createTestUpdater(), appId); - String host2 = "host2"; - wireMock.stubFor(get(urlEqualTo("/state/v1/config")).inScenario("config request") - .whenScenarioStateIs(Scenario.STARTED) - .willReturn(okJson("{\"config\":{\"generation\":4}}")) - .willSetStateTo(host2)); - wireMock.stubFor(get(urlEqualTo("/state/v1/config")).inScenario("config request") - .whenScenarioStateIs(host2) - .willReturn(okJson("{\"config\":{\"generation\":3}}"))); + wireMock.stubFor(get(urlEqualTo("/state/v1/config")).willReturn(okJson("{\"config\":{\"generation\":4}}"))); + wireMock2.stubFor(get(urlEqualTo("/state/v1/config")).willReturn(okJson("{\"config\":{\"generation\":3}}"))); URI requestUrl = testServer().resolve("/serviceconverge"); URI serviceUrl = testServer().resolve("/serviceconverge/" + hostAndPort(service)); + URI serviceUrl2 = testServer().resolve("/serviceconverge/" + hostAndPort(service2)); HttpResponse response = checker.servicesToCheck(application, requestUrl, Duration.ofSeconds(5)); assertResponse("{\n" + - " \"services\": [\n" + - " {\n" + - " \"host\": \"" + service.getHost() + "\",\n" + - " \"port\": " + service.getPort() + ",\n" + - " \"type\": \"container\",\n" + - " \"url\": \"" + serviceUrl.toString() + "\"\n" + - " },\n" + - " {\n" + - " \"host\": \"" + service.getHost() + "\",\n" + - " \"port\": " + service.getPort() + ",\n" + - " \"type\": \"container\",\n" + - " \"url\": \"" + serviceUrl.toString() + "\"\n" + - " }\n" + - " ],\n" + - " \"url\": \"" + requestUrl.toString() + "\",\n" + - " \"currentGeneration\": 3,\n" + - " \"wantedGeneration\": 4,\n" + - " \"converged\": false\n" + - "}", - 200, - response); + " \"services\": [\n" + + " {\n" + + " \"host\": \"" + service.getHost() + "\",\n" + + " \"port\": " + service.getPort() + ",\n" + + " \"type\": \"container\",\n" + + " \"url\": \"" + serviceUrl.toString() + "\",\n" + + " \"currentGeneration\":" + 4 + "\n" + + " },\n" + + " {\n" + + " \"host\": \"" + service2.getHost() + "\",\n" + + " \"port\": " + service2.getPort() + ",\n" + + " \"type\": \"container\",\n" + + " \"url\": \"" + serviceUrl2.toString() + "\",\n" + + " \"currentGeneration\":" + 3 + "\n" + + " }\n" + + " ],\n" + + " \"url\": \"" + requestUrl.toString() + "\",\n" + + " \"currentGeneration\": 3,\n" + + " \"wantedGeneration\": 4,\n" + + " \"converged\": false\n" + + "}", + 200, + response); } } @@ -193,6 +193,10 @@ public class ConfigConvergenceCheckerTest { } private URI testServer() { + return testServer(wireMock); + } + + private URI testServer(WireMockRule wireMock) { return URI.create("http://127.0.0.1:" + wireMock.port()); } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/LogRetrieverTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/LogRetrieverTest.java new file mode 100644 index 00000000000..eb819053c05 --- /dev/null +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/LogRetrieverTest.java @@ -0,0 +1,54 @@ +package com.yahoo.vespa.config.server.http; + +import com.github.tomakehurst.wiremock.junit.WireMockRule; +import com.yahoo.container.jdisc.HttpResponse; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; + + +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.okJson; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; +import static org.junit.Assert.assertEquals; + +public class LogRetrieverTest { + + private String logServerHostName = "http://localhost:8080/"; + private LogRetriever logRetriever; + + @Rule + public final WireMockRule wireMock = new WireMockRule(options().port(8080), true); + + @Before + public void setup() { + logRetriever = new LogRetriever(); + } + + @Test + public void testThatLogHandlerPropagatesResponseBody() throws IOException { + String expectedBody = "{logs-json}"; + stubFor(get(urlEqualTo("/")).willReturn(okJson(expectedBody))); + HttpResponse response = logRetriever.getLogs(logServerHostName); + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + response.render(byteArrayOutputStream); + assertEquals(expectedBody, byteArrayOutputStream.toString()); + assertEquals(200, response.getStatus()); + } + + @Test + public void testThatNotFoundLogServerReturns404() throws IOException { + stubFor(get(urlEqualTo("/")).willReturn(aResponse().withStatus(200))); + HttpResponse response = logRetriever.getLogs("http://wrong-host:8080/"); + assertEquals(404, response.getStatus()); + } + + + +}
\ No newline at end of file diff --git a/container-accesslogging/pom.xml b/container-accesslogging/pom.xml index 0d7b134c58c..b2c9bd7db8b 100644 --- a/container-accesslogging/pom.xml +++ b/container-accesslogging/pom.xml @@ -87,17 +87,6 @@ <artifactId>bundle-plugin</artifactId> <extensions>true</extensions> </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-compiler-plugin</artifactId> - <configuration> - <compilerArgs> - <arg>-Xlint:all</arg> - <arg>-Xlint:-serial</arg> - <arg>-Werror</arg> - </compilerArgs> - </configuration> - </plugin> </plugins> <outputDirectory>${buildOutputDirectory}</outputDirectory> </build> diff --git a/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java b/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java index adadd0b1414..595bd99a759 100644 --- a/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java +++ b/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java @@ -175,7 +175,7 @@ public class JSONFormatter { duration = new BigDecimal(0xffffffff); } - return duration.setScale(3, BigDecimal.ROUND_HALF_UP); + return duration.setScale(3, RoundingMode.HALF_UP); } private static String getNormalizedURI(String rawPath, String rawQuery) { diff --git a/container-accesslogging/src/main/java/com/yahoo/container/logging/LogFileHandler.java b/container-accesslogging/src/main/java/com/yahoo/container/logging/LogFileHandler.java index c7e2a777695..d729b092670 100644 --- a/container-accesslogging/src/main/java/com/yahoo/container/logging/LogFileHandler.java +++ b/container-accesslogging/src/main/java/com/yahoo/container/logging/LogFileHandler.java @@ -2,6 +2,7 @@ package com.yahoo.container.logging; import com.yahoo.container.core.AccessLogConfig; +import com.yahoo.log.LogFileDb; import java.io.File; import java.io.FileOutputStream; @@ -250,6 +251,7 @@ public class LogFileHandler extends StreamHandler { FileOutputStream os = new FileOutputStream(fileName, true); // append mode, for safety super.setOutputStream(os); currentOutputStream = os; + if (! useSequenceNameScheme) LogFileDb.nowLoggingTo(fileName); } catch (IOException e) { throw new RuntimeException("Couldn't open log file '" + fileName + "'", e); @@ -310,7 +312,9 @@ public class LogFileHandler extends StreamHandler { if (thisN>largestN) largestN=thisN; } - file.renameTo(new File(dir,file.getName() + "." + (largestN + 1))); + File newFn = new File(dir, file.getName() + "." + (largestN + 1)); + LogFileDb.nowLoggingTo(newFn.getAbsolutePath()); + file.renameTo(newFn); } /** diff --git a/container-accesslogging/src/main/resources/configdefinitions/access-log.def b/container-accesslogging/src/main/resources/configdefinitions/access-log.def index 276128e0405..9df9299ae19 100644 --- a/container-accesslogging/src/main/resources/configdefinitions/access-log.def +++ b/container-accesslogging/src/main/resources/configdefinitions/access-log.def @@ -1,11 +1,16 @@ # Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. namespace=container.core - # File name patterns supporting the expected time variables, e.g. ".%Y%m%d%H%M%S" fileHandler.pattern string + +# When should rotation happen, in minutes after midnight +# Does this really need to be configurable? +# Could just configure "every N minutes" instead fileHandler.rotation string default="0 60 ..." +# TODO remove in Vespa 7, always use DATE +# # Defines how file rotation is done. There are two options: # # DATE: @@ -27,4 +32,5 @@ fileHandler.rotateScheme enum {DATE, SEQUENCE} default=DATE fileHandler.symlink string default="" # compress the previous access log after rotation +# TODO change to "true" for Vespa 7 fileHandler.compressOnRotation bool default=false diff --git a/container-core/src/main/java/com/yahoo/container/Container.java b/container-core/src/main/java/com/yahoo/container/Container.java index efe7c58563c..7e5ea7bd948 100755 --- a/container-core/src/main/java/com/yahoo/container/Container.java +++ b/container-core/src/main/java/com/yahoo/container/Container.java @@ -145,7 +145,7 @@ public class Container { /** * Only for internal use. */ - public void setCustomFileAcquirer(final FileAcquirer fileAcquirer) { + public void setCustomFileAcquirer(FileAcquirer fileAcquirer) { if (this.fileAcquirer != null) { throw new RuntimeException("Can't change file acquirer. Is " + this.fileAcquirer + " attempted to set to " + fileAcquirer); @@ -155,7 +155,7 @@ public class Container { setPathAcquirer(fileAcquirer); } - private static void setPathAcquirer(final FileAcquirer fileAcquirer) { + private static void setPathAcquirer(FileAcquirer fileAcquirer) { ConfigTransformer.setPathAcquirer(fileReference -> { try { return fileAcquirer.waitFor(fileReference, 15, TimeUnit.MINUTES).toPath(); diff --git a/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java b/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java index c94dc30fd6f..ee12c7d4c9f 100644 --- a/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java +++ b/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java @@ -3,12 +3,13 @@ package com.yahoo.container.core; /** * @author gjoranv - * @since 5.46 */ public interface BundleLoaderProperties { + // TODO: This should be removed. The prefix is used to separate the bundles in BundlesConfig // into those that are transferred with filedistribution and those that are preinstalled // on disk. Instead, the model should have put them in two different configs. I.e. create a new // config 'preinstalled-bundles.def'. - public static final String DISK_BUNDLE_PREFIX = "file:"; + String DISK_BUNDLE_PREFIX = "file:"; + } diff --git a/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java b/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java index eceb41f9739..557f331395b 100644 --- a/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java +++ b/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java @@ -41,8 +41,7 @@ public class BundleLoader { initialBundles = Arrays.asList(osgi.getBundles()); } - private List<Bundle> obtainBundles(FileReference reference, FileAcquirer fileAcquirer) - throws InterruptedException { + private List<Bundle> obtainBundles(FileReference reference, FileAcquirer fileAcquirer) throws InterruptedException { File file = fileAcquirer.waitFor(reference, 7, TimeUnit.DAYS); return osgi.install(file.getAbsolutePath()); } @@ -95,7 +94,7 @@ public class BundleLoader { log.info("Installing bundle from disk with reference '" + reference.value() + "'"); File file = new File(referenceFileName); - if (!file.exists()) { + if ( ! file.exists()) { throw new IllegalArgumentException("Reference '" + reference.value() + "' not found on disk."); } diff --git a/container-core/src/main/java/com/yahoo/container/handler/LogHandler.java b/container-core/src/main/java/com/yahoo/container/handler/LogHandler.java new file mode 100644 index 00000000000..4183b642af1 --- /dev/null +++ b/container-core/src/main/java/com/yahoo/container/handler/LogHandler.java @@ -0,0 +1,45 @@ +package com.yahoo.container.handler; + +import com.google.inject.Inject; +import com.yahoo.container.jdisc.HttpRequest; +import com.yahoo.container.jdisc.HttpResponse; +import com.yahoo.container.jdisc.ThreadedHttpRequestHandler; +import org.json.JSONException; +import org.json.JSONObject; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.util.concurrent.Executor; + +public class LogHandler extends ThreadedHttpRequestHandler { + + private static final String LOG_DIRECTORY = "/home/y/logs/vespa/"; + + @Inject + public LogHandler(Executor executor) { + super(executor); + } + + @Override + public HttpResponse handle(HttpRequest request) { + JSONObject logJson; + + try { + logJson = LogReader.readLogs(LOG_DIRECTORY); + } catch (IOException | JSONException e) { + return new HttpResponse(404) { + @Override + public void render(OutputStream outputStream) {} + }; + } + return new HttpResponse(200) { + @Override + public void render(OutputStream outputStream) throws IOException { + OutputStreamWriter outputStreamWriter = new OutputStreamWriter(outputStream); + outputStreamWriter.write(logJson.toString()); + outputStreamWriter.close(); + } + }; + } +} diff --git a/container-core/src/main/java/com/yahoo/container/handler/LogReader.java b/container-core/src/main/java/com/yahoo/container/handler/LogReader.java new file mode 100644 index 00000000000..eb00446dd0e --- /dev/null +++ b/container-core/src/main/java/com/yahoo/container/handler/LogReader.java @@ -0,0 +1,32 @@ +package com.yahoo.container.handler; + +import org.json.JSONException; +import org.json.JSONObject; + +import javax.xml.bind.DatatypeConverter; +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; + +public class LogReader { + + protected static JSONObject readLogs(String logDirectory) throws IOException, JSONException { + JSONObject json = new JSONObject(); + File root = new File(logDirectory); + traverse_folder(root, json); + return json; + } + + private static void traverse_folder(File root, JSONObject json) throws IOException, JSONException { + for(File child : root.listFiles()) { + JSONObject childJson = new JSONObject(); + if(child.isFile()) { + json.put(child.getName(), DatatypeConverter.printBase64Binary(Files.readAllBytes(child.toPath()))); + } + else { + json.put(child.getName(), childJson); + traverse_folder(child, childJson); + } + } + } +} diff --git a/container-core/src/test/java/com/yahoo/container/handler/LogReaderTest.java b/container-core/src/test/java/com/yahoo/container/handler/LogReaderTest.java new file mode 100644 index 00000000000..e5302ee43ee --- /dev/null +++ b/container-core/src/test/java/com/yahoo/container/handler/LogReaderTest.java @@ -0,0 +1,28 @@ +package com.yahoo.container.handler; + +import org.json.JSONObject; +import org.junit.Before; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; + +import static org.junit.Assert.*; + +public class LogReaderTest { + + ByteArrayOutputStream outputStream; + + @Before + public void setup() { + outputStream = new ByteArrayOutputStream(); + } + + @Test + public void testThatFilesAreWrittenCorrectlyToOutputStream() throws Exception{ + String logDirectory = "src/test/resources/logfolder/"; + JSONObject json = LogReader.readLogs(logDirectory); + String expected = "{\"subfolder\":{\"log2.log\":\"VGhpcyBpcyBhbm90aGVyIGxvZyBmaWxl\"},\"log1.log\":\"VGhpcyBpcyBvbmUgbG9nIGZpbGU=\"}"; + String actual = json.toString(); + assertEquals(expected, actual); + } +}
\ No newline at end of file diff --git a/container-core/src/test/resources/logfolder/log1.log b/container-core/src/test/resources/logfolder/log1.log new file mode 100644 index 00000000000..bb85d5a4950 --- /dev/null +++ b/container-core/src/test/resources/logfolder/log1.log @@ -0,0 +1 @@ +This is one log file
\ No newline at end of file diff --git a/container-core/src/test/resources/logfolder/subfolder/log2.log b/container-core/src/test/resources/logfolder/subfolder/log2.log new file mode 100644 index 00000000000..aee6eaca2e8 --- /dev/null +++ b/container-core/src/test/resources/logfolder/subfolder/log2.log @@ -0,0 +1 @@ +This is another log file
\ No newline at end of file diff --git a/container-dependency-versions/pom.xml b/container-dependency-versions/pom.xml index ccb8a9c311c..259fcfb8de7 100644 --- a/container-dependency-versions/pom.xml +++ b/container-dependency-versions/pom.xml @@ -466,7 +466,7 @@ <guava.version>18.0</guava.version> <guice.version>3.0</guice.version> <jaxb.version>2.3.0</jaxb.version> - <jetty.version>9.4.10.v20180503</jetty.version> + <jetty.version>9.4.12.v20180830</jetty.version> <slf4j.version>1.7.5</slf4j.version> <!-- These must be kept in sync with version used by current jersey2.version. --> diff --git a/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/ComponentGraph.java b/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/ComponentGraph.java index 463de0c089a..76ca94c9286 100644 --- a/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/ComponentGraph.java +++ b/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/ComponentGraph.java @@ -77,7 +77,6 @@ public class ComponentGraph { private Optional<Node> lookupGlobalComponent(Key<?> key) { if (!(key.getTypeLiteral().getType() instanceof Class)) { - throw new RuntimeException("Type not supported " + key.getTypeLiteral()); } Class<?> clazz = key.getTypeLiteral().getRawType(); diff --git a/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/Exceptions.java b/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/Exceptions.java index d84d771fef6..e8c527aeaef 100644 --- a/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/Exceptions.java +++ b/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/Exceptions.java @@ -3,6 +3,7 @@ package com.yahoo.container.di.componentgraph.core; import java.util.Arrays; class Exceptions { + static <E extends Throwable> E removeStackTrace(E exception) { if (preserveStackTrace()) { return exception; diff --git a/container-disc/CMakeLists.txt b/container-disc/CMakeLists.txt index 1b661020166..92f5b303d41 100644 --- a/container-disc/CMakeLists.txt +++ b/container-disc/CMakeLists.txt @@ -6,7 +6,6 @@ vespa_install_script(src/main/sh/vespa-start-container-daemon.sh vespa-start-con install_config_definition(src/main/resources/configdefinitions/container.jdisc.config.http-server.def) install_config_definition(src/main/resources/configdefinitions/jdisc-bindings.def container.jdisc.jdisc-bindings.def) install_config_definition(src/main/resources/configdefinitions/jersey-connection.def container.config.jersey.jersey-connection.def) -install_config_definition(src/main/resources/configdefinitions/jersey-init.def container.config.jersey.jersey-init.def) install_config_definition(src/main/resources/configdefinitions/jersey-web-app-pool.def container.config.jersey.jersey-web-app-pool.def) install_config_definition(src/main/resources/configdefinitions/metric-defaults.def container.jdisc.config.metric-defaults.def) install_config_definition(src/main/resources/configdefinitions/score-board.def jdisc.metrics.yamasconsumer.cloud.score-board.def) diff --git a/container-disc/src/main/resources/configdefinitions/jersey-init.def b/container-disc/src/main/resources/configdefinitions/jersey-init.def deleted file mode 100644 index 95ec9f23906..00000000000 --- a/container-disc/src/main/resources/configdefinitions/jersey-init.def +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -## Do NOT move this file to the container-jersey module. If system bundles -## like config-model import packages from container-jersey, new class -## loaders for these bundles will be created after reconfig. -namespace=container.config.jersey - -# Controlled by the config framework, do not set this from services.xml! -jerseyMapping string diff --git a/container-disc/src/main/sh/vespa-start-container-daemon.sh b/container-disc/src/main/sh/vespa-start-container-daemon.sh index e6219ab0467..21c9dc28022 100755 --- a/container-disc/src/main/sh/vespa-start-container-daemon.sh +++ b/container-disc/src/main/sh/vespa-start-container-daemon.sh @@ -52,7 +52,7 @@ getconfig() { qrstartcfg="`cat ${config_dir}/qr-start.cfg`" ;; *) - qrstartcfg="`$VESPA_HOME/bin/vespa-get-config -w 10 -n search.config.qr-start -i ${VESPA_CONFIG_ID}`" + qrstartcfg="`$VESPA_HOME/bin/vespa-get-config -l -w 10 -n search.config.qr-start -i ${VESPA_CONFIG_ID}`" ;; esac cmds=`echo "$qrstartcfg" | perl -ne 's/^(\w+)\.(\w+) (.*)/$1_$2=$3/ && print'` diff --git a/container-search/src/main/java/com/yahoo/fs4/mplex/Backend.java b/container-search/src/main/java/com/yahoo/fs4/mplex/Backend.java index 3eabc3c6a6c..2a90e746378 100644 --- a/container-search/src/main/java/com/yahoo/fs4/mplex/Backend.java +++ b/container-search/src/main/java/com/yahoo/fs4/mplex/Backend.java @@ -61,12 +61,11 @@ public class Backend implements ConnectionFactory { private final ConnectionPool connectionPool; private final PacketDumper packetDumper; private final AtomicInteger connectionCount = new AtomicInteger(0); - private final Optional<Integer> distributionKey; /** * For unit testing. do not use */ - protected Backend(Optional<Integer> distributionKey) { + protected Backend() { listeners = null; host = null; port = 0; @@ -74,15 +73,13 @@ public class Backend implements ConnectionFactory { packetDumper = null; address = null; connectionPool = new ConnectionPool(); - this.distributionKey = distributionKey; } public Backend(String host, int port, String serverDiscriminator, ListenerPool listenerPool, - ConnectionPool connectionPool, - Optional<Integer> distributionKey) { + ConnectionPool connectionPool) { String fileNamePattern = "qrs." + serverDiscriminator + '.' + host + ":" + port + ".%s" + ".dump"; packetDumper = new PacketDumper(new File(Defaults.getDefaults().underVespaHome("logs/vespa/qrs/")), fileNamePattern); @@ -92,7 +89,6 @@ public class Backend implements ConnectionFactory { this.port = port; address = new InetSocketAddress(host, port); this.connectionPool = connectionPool; - this.distributionKey = distributionKey; } private void logWarning(String attemptDescription, Exception e) { @@ -103,9 +99,6 @@ public class Backend implements ConnectionFactory { log.log(Level.INFO, "Exception on " + attemptDescription + " '" + host + ":" + port + "': " + Exceptions.toMessageString(e)); } - /** Returns the distribution key of the content node this represents, or empty if it is a dispatch node */ - public Optional<Integer> distributionKey() { return distributionKey; } - // ============================================================ // ==== connection pool stuff // ============================================================ diff --git a/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java b/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java index 237b0cdb8e2..de4d9c9fe8b 100644 --- a/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java +++ b/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java @@ -4,7 +4,6 @@ package com.yahoo.fs4.mplex; import java.io.IOException; import java.util.ArrayList; import java.util.List; -import java.util.Optional; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -67,9 +66,6 @@ public class FS4Channel { return channelId; } - /** Returns the distribution key of the content node this represents, or empty if it is a dispatch node */ - public Optional<Integer> distributionKey() { return backend == null ? Optional.empty() : backend.distributionKey(); } - /** * Closes the channel */ diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4ResourcePool.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4ResourcePool.java index e933f4857b3..51b3146a609 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4ResourcePool.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4ResourcePool.java @@ -9,12 +9,9 @@ import com.yahoo.container.search.Fs4Config; import com.yahoo.fs4.mplex.Backend; import com.yahoo.fs4.mplex.ConnectionPool; import com.yahoo.fs4.mplex.ListenerPool; -import com.yahoo.io.Connection; -import java.io.IOException; import java.util.HashMap; import java.util.Map; -import java.util.Optional; import java.util.Timer; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -61,14 +58,11 @@ public class FS4ResourcePool extends AbstractComponent { } public Backend getBackend(String host, int port) { - return getBackend(host, port, Optional.empty()); - } - public Backend getBackend(String host, int port, Optional<Integer> distributionKey) { String key = host + ":" + port; synchronized (connectionPoolMap) { Backend pool = connectionPoolMap.get(key); if (pool == null) { - pool = new Backend(host, port, Server.get().getServerDiscriminator(), listeners, new ConnectionPool(timer), distributionKey); + pool = new Backend(host, port, Server.get().getServerDiscriminator(), listeners, new ConnectionPool(timer)); connectionPoolMap.put(key, pool); } return pool; diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java index c93220f0a85..333d3970cc4 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java @@ -22,6 +22,7 @@ import com.yahoo.prelude.querytransform.QueryRewrite; import com.yahoo.processing.request.CompoundName; import com.yahoo.search.Query; import com.yahoo.search.Result; +import com.yahoo.search.dispatch.CloseableChannel; import com.yahoo.search.dispatch.Dispatcher; import com.yahoo.search.dispatch.SearchCluster; import com.yahoo.search.grouping.GroupingRequest; @@ -62,6 +63,10 @@ public class FastSearcher extends VespaBackEndSearcher { /** The compression method which will be used with rpc dispatch. "lz4" (default) and "none" is supported. */ private final static CompoundName dispatchCompression = new CompoundName("dispatch.compression"); + /** If enabled, the dispatcher internal to the search container will be preferred over fdispatch + * whenever possible */ + private static final CompoundName dispatchInternal = new CompoundName("dispatch.internal"); + /** Used to dispatch directly to search nodes over RPC, replacing the old fnet communication path */ private final Dispatcher dispatcher; @@ -168,11 +173,9 @@ public class FastSearcher extends VespaBackEndSearcher { @Override public Result doSearch2(Query query, QueryPacket queryPacket, CacheKey cacheKey, Execution execution) { - FS4Channel channel = null; - try { - if (dispatcher.searchCluster().groupSize() == 1) - forceSinglePassGrouping(query); - channel = chooseBackend(query).openChannel(); + if (dispatcher.searchCluster().groupSize() == 1) + forceSinglePassGrouping(query); + try(CloseableChannel channel = getChannel(query)) { channel.setQuery(query); Result result = searchTwoPhase(channel, query, queryPacket, cacheKey); @@ -195,9 +198,6 @@ public class FastSearcher extends VespaBackEndSearcher { query.trace(getName() + " error response: " + result, false, 1); result.hits().addError(ErrorMessage.createBackendCommunicationError(getName() + " failed: "+ e.getMessage())); return result; - } finally { - if (channel != null) - channel.close(); } } @@ -214,24 +214,32 @@ public class FastSearcher extends VespaBackEndSearcher { } /** - * Returns the backend object to issue a search request over. - * Normally this is the backend field of this instance, which connects to the dispatch node this talk to - * (which is why this instance was chosen by the cluster controller). However, when certain conditions obtain - * (see below), we will instead return a backend instance which connects directly to the local search node - * for efficiency. + * Returns an interface object to issue a search request over. + * Normally this is built from the backend field of this instance, which connects to the dispatch node + * this component talks to (which is why this instance was chosen by the cluster controller). However, + * under certain conditions we will instead return an interface which connects directly to the relevant + * search nodes. */ - private Backend chooseBackend(Query query) { - if ( ! query.properties().getBoolean(dispatchDirect, true)) return dispatchBackend; - if (query.properties().getBoolean(com.yahoo.search.query.Model.ESTIMATE)) return dispatchBackend; + private CloseableChannel getChannel(Query query) { + if (query.properties().getBoolean(dispatchInternal, false)) { + Optional<CloseableChannel> dispatchedChannel = dispatcher.getDispatchedChannel(query); + if (dispatchedChannel.isPresent()) { + return dispatchedChannel.get(); + } + } + if (!query.properties().getBoolean(dispatchDirect, true)) + return new CloseableChannel(dispatchBackend); + if (query.properties().getBoolean(com.yahoo.search.query.Model.ESTIMATE)) + return new CloseableChannel(dispatchBackend); Optional<SearchCluster.Node> directDispatchRecipient = dispatcher.searchCluster().directDispatchTarget(); - if ( ! directDispatchRecipient.isPresent()) return dispatchBackend; + if (!directDispatchRecipient.isPresent()) + return new CloseableChannel(dispatchBackend); // Dispatch directly to the single, local search node query.trace(false, 2, "Dispatching directly to ", directDispatchRecipient.get()); - return fs4ResourcePool.getBackend(directDispatchRecipient.get().hostname(), - directDispatchRecipient.get().fs4port(), - Optional.of(directDispatchRecipient.get().key())); + return new CloseableChannel(fs4ResourcePool.getBackend(directDispatchRecipient.get().hostname(), + directDispatchRecipient.get().fs4port()), Optional.of(directDispatchRecipient.get().key())); } /** @@ -270,10 +278,9 @@ public class FastSearcher extends VespaBackEndSearcher { packetWrapper = cacheLookupTwoPhase(cacheKey, result, summaryClass); } - FS4Channel channel = chooseBackend(query).openChannel(); - channel.setQuery(query); Packet[] receivedPackets; - try { + try(CloseableChannel channel = getChannel(query)) { + channel.setQuery(query); DocsumPacketKey[] packetKeys; if (countFastHits(result) > 0) { @@ -340,8 +347,6 @@ public class FastSearcher extends VespaBackEndSearcher { query.trace(traceMsg, false, 3); } } - } finally { - channel.close(); } } @@ -373,7 +378,7 @@ public class FastSearcher extends VespaBackEndSearcher { return null; } - private Result searchTwoPhase(FS4Channel channel, Query query, QueryPacket queryPacket, CacheKey cacheKey) throws IOException { + private Result searchTwoPhase(CloseableChannel channel, Query query, QueryPacket queryPacket, CacheKey cacheKey) throws IOException { if (isLoggingFine()) getLogger().finest("sending query packet"); @@ -453,7 +458,7 @@ public class FastSearcher extends VespaBackEndSearcher { return packets; } - private Packet[] fetchSummaries(FS4Channel channel, Result result, String summaryClass) + private Packet[] fetchSummaries(CloseableChannel channel, Result result, String summaryClass) throws InvalidChannelException, ChannelTimeoutException, ClassCastException, IOException { BasicPacket[] receivedPackets; diff --git a/container-search/src/main/java/com/yahoo/prelude/query/CompositeItem.java b/container-search/src/main/java/com/yahoo/prelude/query/CompositeItem.java index 2c05f2e7edf..eee9949d831 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/CompositeItem.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/CompositeItem.java @@ -73,8 +73,7 @@ public abstract class CompositeItem extends Item { */ public void addItem(int index, Item item) { if (index > subitems.size() || index < 0) { - throw new IndexOutOfBoundsException( - "Could not add a subitem at position " + index + " to " + this); + throw new IndexOutOfBoundsException("Could not add a subitem at position " + index + " to " + this); } adding(item); subitems.add(index, item); diff --git a/container-search/src/main/java/com/yahoo/prelude/query/NonReducibleCompositeItem.java b/container-search/src/main/java/com/yahoo/prelude/query/NonReducibleCompositeItem.java index 547825cb51c..84aa177369a 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/NonReducibleCompositeItem.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/NonReducibleCompositeItem.java @@ -7,10 +7,9 @@ package com.yahoo.prelude.query; * <p> * Most composites, like AND and OR, are reducible as e.g (AND a) is semantically equal to (a). * <p> - * This type functions as a marked interfaces for query rewriters. + * This type functions as a marker type for query rewriters. * * @author bratseth - * @since 5.1.22 */ public abstract class NonReducibleCompositeItem extends CompositeItem { } diff --git a/container-search/src/main/java/com/yahoo/prelude/query/SameElementItem.java b/container-search/src/main/java/com/yahoo/prelude/query/SameElementItem.java index ca2c5a80283..aa446140da0 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/SameElementItem.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/SameElementItem.java @@ -11,6 +11,7 @@ import java.util.Iterator; * This represents a query where all terms are required to match in the same element id. * The primary usecase is to allow efficient search in arrays and maps of struct. * The common path is the field name containing the struct. + * * @author baldersheim */ @Beta diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/CloseableChannel.java b/container-search/src/main/java/com/yahoo/search/dispatch/CloseableChannel.java new file mode 100644 index 00000000000..643b8f81318 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/dispatch/CloseableChannel.java @@ -0,0 +1,54 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.dispatch; + +import com.yahoo.fs4.BasicPacket; +import com.yahoo.fs4.ChannelTimeoutException; +import com.yahoo.fs4.mplex.Backend; +import com.yahoo.fs4.mplex.FS4Channel; +import com.yahoo.fs4.mplex.InvalidChannelException; +import com.yahoo.search.Query; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Optional; + +/** + * @author ollivir + */ +public class CloseableChannel implements Closeable { + private FS4Channel channel; + private final Optional<Integer> distributionKey; + + public CloseableChannel(Backend backend) { + this(backend, Optional.empty()); + } + + public CloseableChannel(Backend backend, Optional<Integer> distributionKey) { + this.channel = backend.openChannel(); + this.distributionKey = distributionKey; + } + + public void setQuery(Query query) { + channel.setQuery(query); + } + + public boolean sendPacket(BasicPacket packet) throws InvalidChannelException, IOException { + return channel.sendPacket(packet); + } + + public BasicPacket[] receivePackets(long timeout, int packetCount) throws InvalidChannelException, ChannelTimeoutException { + return channel.receivePackets(timeout, packetCount); + } + + public Optional<Integer> distributionKey() { + return distributionKey; + } + + @Override + public void close() { + if (channel != null) { + channel.close(); + channel = null; + } + } +} diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/DispatchedChannel.java b/container-search/src/main/java/com/yahoo/search/dispatch/DispatchedChannel.java new file mode 100644 index 00000000000..d005d9491d5 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/dispatch/DispatchedChannel.java @@ -0,0 +1,38 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.dispatch; + +import com.yahoo.prelude.fastsearch.FS4ResourcePool; +import com.yahoo.search.dispatch.SearchCluster.Group; +import com.yahoo.search.dispatch.SearchCluster.Node; + +import java.util.Optional; + +/** + * An extension to CloseableChannel that encapsulates the release of a LoadBalancer group allocation. + * + * @author ollivir + */ +public class DispatchedChannel extends CloseableChannel { + private final SearchCluster.Group group; + private final LoadBalancer loadBalancer; + private boolean groupAllocated = true; + + public DispatchedChannel(FS4ResourcePool fs4ResourcePool, LoadBalancer loadBalancer, Group group, Node node) { + super(fs4ResourcePool.getBackend(node.hostname(), node.fs4port()), Optional.of(node.key())); + + this.loadBalancer = loadBalancer; + this.group = group; + } + + public DispatchedChannel(FS4ResourcePool fs4ResourcePool, LoadBalancer loadBalancer, Group group) { + this(fs4ResourcePool, loadBalancer, group, group.nodes().iterator().next()); + } + + public void close() { + if (groupAllocated) { + groupAllocated = false; + loadBalancer.releaseGroup(group); + } + super.close(); + } +} diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java index 5ef81403f26..c383b681558 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java @@ -28,6 +28,7 @@ import com.yahoo.vespa.config.search.DispatchConfig; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -57,10 +58,15 @@ public class Dispatcher extends AbstractComponent { private final Compressor compressor = new Compressor(); + private final LoadBalancer loadBalancer; + private final FS4ResourcePool fs4ResourcePool; + public Dispatcher(DispatchConfig dispatchConfig, FS4ResourcePool fs4ResourcePool, int containerClusterSize, VipStatus vipStatus) { this.client = new RpcClient(); this.searchCluster = new SearchCluster(dispatchConfig, fs4ResourcePool, containerClusterSize, vipStatus); + this.fs4ResourcePool = fs4ResourcePool; + this.loadBalancer = new LoadBalancer(searchCluster); // Create node rpc connections, indexed by the node distribution key ImmutableMap.Builder<Integer, Client.NodeConnection> nodeConnectionsBuilder = new ImmutableMap.Builder<>(); @@ -75,6 +81,8 @@ public class Dispatcher extends AbstractComponent { this.searchCluster = null; this.nodeConnections = ImmutableMap.copyOf(nodeConnections); this.client = client; + this.fs4ResourcePool = null; + this.loadBalancer = new LoadBalancer(searchCluster); } /** Returns the search cluster this dispatches to */ @@ -275,4 +283,18 @@ public class Dispatcher extends AbstractComponent { } + public Optional<CloseableChannel> getDispatchedChannel(Query query) { + Optional<SearchCluster.Group> groupInCluster = loadBalancer.takeGroupForQuery(query); + + return groupInCluster.flatMap(group -> { + if(group.nodes().size() == 1) { + SearchCluster.Node node = group.nodes().iterator().next(); + query.trace(false, 2, "Dispatching internally to ", group, " (", node.toString(), ")"); + return Optional.of(new DispatchedChannel(fs4ResourcePool, loadBalancer, group)); + } else { + loadBalancer.releaseGroup(group); + return Optional.empty(); + } + }); + } } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java b/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java new file mode 100644 index 00000000000..d8e12980472 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java @@ -0,0 +1,138 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.dispatch; + +import com.yahoo.search.Query; +import com.yahoo.search.dispatch.SearchCluster.Group; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * LoadBalancer determines which group of content nodes should be accessed next for each search query when the internal java dispatcher is + * used. + * + * @author ollivir + */ +public class LoadBalancer { + // The implementation here is a simplistic least queries in flight + round-robin load balancer + // TODO: consider the options in com.yahoo.vespa.model.content.TuningDispatch + + private final static Logger log = Logger.getLogger(LoadBalancer.class.getName()); + + private final boolean isInternallyDispatchable; + private final List<GroupSchedule> scoreboard; + private int needle = 0; + + public LoadBalancer(SearchCluster searchCluster) { + if (searchCluster == null) { + this.isInternallyDispatchable = false; + this.scoreboard = null; + return; + } + this.isInternallyDispatchable = (searchCluster.groupSize() == 1); + this.scoreboard = new ArrayList<>(searchCluster.groups().size()); + + for (Group group : searchCluster.groups().values()) { + scoreboard.add(new GroupSchedule(group)); + } + Collections.shuffle(scoreboard); + } + + /** + * Select and allocate the search cluster group which is to be used for the provided query. Callers <b>must</b> call + * {@link #releaseGroup(Group)} symmetrically for each taken allocation. + * + * @param query + * @return The node group to target, or <i>empty</i> if the internal dispatch logic cannot be used + */ + public Optional<Group> takeGroupForQuery(Query query) { + if (!isInternallyDispatchable) { + return Optional.empty(); + } + + return allocateNextGroup(); + } + + /** + * Release an allocation given by {@link #takeGroupForQuery(Query)}. The release must be done exactly once for each allocation. + * + * @param group + * previously allocated group + */ + public void releaseGroup(Group group) { + synchronized (this) { + for (GroupSchedule sched : scoreboard) { + if (sched.group.id() == group.id()) { + sched.adjustScore(-1); + break; + } + } + } + } + + private Optional<Group> allocateNextGroup() { + synchronized (this) { + GroupSchedule bestSchedule = null; + + int index = needle; + for (int i = 0; i < scoreboard.size(); i++) { + GroupSchedule sched = scoreboard.get(index); + if (sched.isPreferredOver(bestSchedule)) { + bestSchedule = sched; + } + index = nextScoreboardIndex(index); + } + needle = nextScoreboardIndex(needle); + + Group ret = null; + if (bestSchedule != null) { + bestSchedule.adjustScore(1); + ret = bestSchedule.group; + } + if (log.isLoggable(Level.FINE)) { + log.fine("Offering <" + ret + "> for query connection"); + } + return Optional.ofNullable(ret); + } + } + + private int nextScoreboardIndex(int current) { + int next = current + 1; + if (next >= scoreboard.size()) { + next %= scoreboard.size(); + } + return next; + } + + private static class GroupSchedule { + private final Group group; + private int score; + + public GroupSchedule(Group group) { + this.group = group; + this.score = 0; + } + + public boolean isPreferredOver(GroupSchedule other) { + if (! group.hasSufficientCoverage()) { + return false; + } + if (other == null) { + return true; + } + return this.score < other.score; + } + + public void adjustScore(int amount) { + this.score += amount; + if (score < 0) { + log.warning("Double free of query target group detected"); + score = 0; + } + } + } +} diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/RpcClient.java b/container-search/src/main/java/com/yahoo/search/dispatch/RpcClient.java index 8107a50e8c0..67e032eca37 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/RpcClient.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/RpcClient.java @@ -15,7 +15,6 @@ import com.yahoo.jrt.Values; import com.yahoo.prelude.fastsearch.FastHit; import java.util.List; -import java.util.concurrent.atomic.AtomicReference; /** * A client which uses rpc request to search nodes to implement the Client API. diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/SearchCluster.java b/container-search/src/main/java/com/yahoo/search/dispatch/SearchCluster.java index efce2fdac9c..48ddba6c301 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/SearchCluster.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/SearchCluster.java @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.dispatch; -import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableCollection; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/container-search/src/main/java/com/yahoo/search/grouping/GroupingValidator.java b/container-search/src/main/java/com/yahoo/search/grouping/GroupingValidator.java index d96f490909e..06b030dbc78 100644 --- a/container-search/src/main/java/com/yahoo/search/grouping/GroupingValidator.java +++ b/container-search/src/main/java/com/yahoo/search/grouping/GroupingValidator.java @@ -5,6 +5,7 @@ import com.google.inject.Inject; import com.yahoo.component.chain.dependencies.After; import com.yahoo.component.chain.dependencies.Before; import com.yahoo.component.chain.dependencies.Provides; +import com.yahoo.search.grouping.request.AttributeMapLookupValue; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.container.QrSearchersConfig; import com.yahoo.processing.request.CompoundName; @@ -18,8 +19,7 @@ import com.yahoo.search.grouping.request.GroupingExpression; import com.yahoo.search.searchchain.Execution; import com.yahoo.search.searchchain.PhaseNames; -import java.util.HashSet; -import java.util.Set; +import java.util.HashMap; import static com.yahoo.search.grouping.GroupingQueryParser.SELECT_PARAMETER_PARSING; @@ -37,7 +37,7 @@ public class GroupingValidator extends Searcher { public static final String GROUPING_VALIDATED = "GroupingValidated"; public static final CompoundName PARAM_ENABLED = new CompoundName("validate_" + GroupingQueryParser.PARAM_REQUEST); - private final Set<String> attributeNames = new HashSet<>(); + private final HashMap<String, AttributesConfig.Attribute> attributes = new HashMap<>(); private final String clusterName; private final boolean enabled; @@ -55,7 +55,7 @@ public class GroupingValidator extends Searcher { enabled = (indexingMode != QrSearchersConfig.Searchcluster.Indexingmode.STREAMING); clusterName = enabled ? qrsConfig.searchcluster(clusterId).name() : null; for (AttributesConfig.Attribute attr : attributesConfig.attribute()) { - attributeNames.add(attr.name()); + attributes.put(attr.name(), attr); } } @@ -69,15 +69,42 @@ public class GroupingValidator extends Searcher { return execution.search(query); } + private void verifyHasAttribute(String attributeName) { + if (!attributes.containsKey(attributeName)) { + throw new UnavailableAttributeException(clusterName, attributeName); + } + } + + private void verifyCompatibleAttributeTypes(String keyAttributeName, + String keySourceAttributeName) { + AttributesConfig.Attribute keyAttribute = attributes.get(keyAttributeName); + AttributesConfig.Attribute keySourceAttribute = attributes.get(keySourceAttributeName); + if (!keySourceAttribute.datatype().equals(keyAttribute.datatype())) { + throw new IllegalArgumentException("Grouping request references key source attribute '" + + keySourceAttributeName + "' with data type '" + keySourceAttribute.datatype() + + "' that is different than data type '" + keyAttribute.datatype() + "' of key attribute '" + + keyAttributeName + "'"); + } + if (!keySourceAttribute.collectiontype().equals(AttributesConfig.Attribute.Collectiontype.Enum.SINGLE)) { + throw new IllegalArgumentException("Grouping request references key source attribute '" + + keySourceAttributeName + "' which is not of single value type"); + } + } + private class MyVisitor implements ExpressionVisitor { @Override public void visitExpression(GroupingExpression exp) { - if (exp instanceof AttributeValue) { - String name = ((AttributeValue)exp).getAttributeName(); - if (!attributeNames.contains(name)) { - throw new UnavailableAttributeException(clusterName, name); + if (exp instanceof AttributeMapLookupValue) { + AttributeMapLookupValue mapLookup = (AttributeMapLookupValue) exp; + verifyHasAttribute(mapLookup.getKeyAttribute()); + verifyHasAttribute(mapLookup.getValueAttribute()); + if (mapLookup.hasKeySourceAttribute()) { + verifyHasAttribute(mapLookup.getKeySourceAttribute()); + verifyCompatibleAttributeTypes(mapLookup.getKeyAttribute(), mapLookup.getKeySourceAttribute()); } + } else if (exp instanceof AttributeValue) { + verifyHasAttribute(((AttributeValue) exp).getAttributeName()); } } } diff --git a/container-search/src/main/java/com/yahoo/search/grouping/request/AttributeMapLookupValue.java b/container-search/src/main/java/com/yahoo/search/grouping/request/AttributeMapLookupValue.java new file mode 100644 index 00000000000..281e6e53b36 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/grouping/request/AttributeMapLookupValue.java @@ -0,0 +1,62 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.grouping.request; + +/** + * This class represents a lookup in a map attribute in a {@link GroupingExpression}. + * + * It evaluates to the value found using the given key for the lookup in that attribute. + * The key is either specified explicitly or found via a key source attribute. + * Two underlying attributes are used to represent the map attribute (the key and value attributes). + * + * @author geirst + */ +public class AttributeMapLookupValue extends AttributeValue { + + private final String prefix; + private final String suffix; + private final String key; + private final String keySourceAttribute; + + private AttributeMapLookupValue(String attributeValue, String prefix, String suffix, String key, String keySourceAttribute) { + super(attributeValue); + this.prefix = prefix; + this.suffix = suffix; + this.key = key; + this.keySourceAttribute = keySourceAttribute; + } + + public static AttributeMapLookupValue fromKey(String prefix, String key, String suffix) { + return new AttributeMapLookupValue(prefix + "{\"" + key + "\"}" + suffix, + prefix, suffix, key, ""); + } + + public static AttributeMapLookupValue fromKeySourceAttribute(String prefix, String keySourceAttribute, String suffix) { + return new AttributeMapLookupValue(prefix + "{attribute(" + keySourceAttribute + ")}" + suffix, + prefix, suffix, "", keySourceAttribute); + } + + @Override + public AttributeMapLookupValue copy() { + return new AttributeMapLookupValue(getAttributeName(), prefix, suffix, key, keySourceAttribute); + } + + public String getKeyAttribute() { + return prefix + ".key"; + } + + public String getValueAttribute() { + return prefix + ".value" + suffix; + } + + public String getKey() { + return key; + } + + public boolean hasKeySourceAttribute() { + return !keySourceAttribute.isEmpty(); + } + + public String getKeySourceAttribute() { + return keySourceAttribute; + } +} diff --git a/container-search/src/main/java/com/yahoo/search/grouping/vespa/ExpressionConverter.java b/container-search/src/main/java/com/yahoo/search/grouping/vespa/ExpressionConverter.java index d2dfb3c0ee7..d017fe2edb3 100644 --- a/container-search/src/main/java/com/yahoo/search/grouping/vespa/ExpressionConverter.java +++ b/container-search/src/main/java/com/yahoo/search/grouping/vespa/ExpressionConverter.java @@ -6,6 +6,7 @@ import com.yahoo.search.grouping.request.AggregatorNode; import com.yahoo.search.grouping.request.AndFunction; import com.yahoo.search.grouping.request.ArrayAtLookup; import com.yahoo.search.grouping.request.AttributeFunction; +import com.yahoo.search.grouping.request.AttributeMapLookupValue; import com.yahoo.search.grouping.request.AttributeValue; import com.yahoo.search.grouping.request.AvgAggregator; import com.yahoo.search.grouping.request.BucketValue; @@ -101,6 +102,7 @@ import com.yahoo.searchlib.expression.AddFunctionNode; import com.yahoo.searchlib.expression.AggregationRefNode; import com.yahoo.searchlib.expression.AndFunctionNode; import com.yahoo.searchlib.expression.ArrayAtLookupNode; +import com.yahoo.searchlib.expression.AttributeMapLookupNode; import com.yahoo.searchlib.expression.AttributeNode; import com.yahoo.searchlib.expression.BucketResultNode; import com.yahoo.searchlib.expression.CatFunctionNode; @@ -263,6 +265,16 @@ class ExpressionConverter { if (exp instanceof AndFunction) { return addArguments(new AndFunctionNode(), (AndFunction)exp); } + if (exp instanceof AttributeMapLookupValue) { + AttributeMapLookupValue mapLookup = (AttributeMapLookupValue) exp; + if (mapLookup.hasKeySourceAttribute()) { + return AttributeMapLookupNode.fromKeySourceAttribute(mapLookup.getAttributeName(), + mapLookup.getKeyAttribute(), mapLookup.getValueAttribute(), mapLookup.getKeySourceAttribute()); + } else { + return AttributeMapLookupNode.fromKey(mapLookup.getAttributeName(), + mapLookup.getKeyAttribute(), mapLookup.getValueAttribute(), mapLookup.getKey()); + } + } if (exp instanceof AttributeValue) { return new AttributeNode(((AttributeValue)exp).getAttributeName()); } diff --git a/container-search/src/main/java/com/yahoo/search/query/rewrite/QueryRewriteSearcher.java b/container-search/src/main/java/com/yahoo/search/query/rewrite/QueryRewriteSearcher.java index 2d0ff0c62db..7b24a00cf60 100644 --- a/container-search/src/main/java/com/yahoo/search/query/rewrite/QueryRewriteSearcher.java +++ b/container-search/src/main/java/com/yahoo/search/query/rewrite/QueryRewriteSearcher.java @@ -36,14 +36,13 @@ import java.util.logging.Logger; public abstract class QueryRewriteSearcher extends Searcher { // Indicate whether rewriter is properly initiated - private boolean isOk = false; + private boolean isOk; protected final Logger logger = Logger.getLogger(QueryRewriteSearcher.class.getName()); // HashMap which store the rewriter dicts // It has the following format: - // HashMap<String(e.g. dictionary name, etc), - // Object(e.g. FSA, etc)>> + // HashMap<String(e.g. dictionary name, etc), Object(e.g. FSA, etc)>> protected HashMap<String, Object> rewriterDicts = new HashMap<>(); /** @@ -201,14 +200,14 @@ public abstract class QueryRewriteSearcher extends Searcher { "FSA file location for " + fsaName + ": " + fsaPath); // Retrieve FSA File handler - File fsaFile = null; - if(fileAcquirer!=null) { + File fsaFile; + if (fileAcquirer != null) { fsaFile = fileAcquirer.waitFor(fsaPath, 5, TimeUnit.MINUTES); - } else if(fileList!=null) { + } else { fsaFile = fileList.get(fsaName); } - if(fsaFile==null) { + if (fsaFile == null) { RewriterUtils.error(logger, "Error loading FSA dictionary file handler"); return false; } diff --git a/container-search/src/main/java/com/yahoo/search/yql/TypeCheckers.java b/container-search/src/main/java/com/yahoo/search/yql/TypeCheckers.java index cb693d6801d..af54d28c2ac 100644 --- a/container-search/src/main/java/com/yahoo/search/yql/TypeCheckers.java +++ b/container-search/src/main/java/com/yahoo/search/yql/TypeCheckers.java @@ -64,7 +64,8 @@ final class TypeCheckers { TypeLiteral<?> arg = TypeLiteral.get(type.getActualTypeArguments()[0]); if (OperatorNode.class.isAssignableFrom(arg.getRawType())) { Preconditions.checkArgument(arg.getType() instanceof ParameterizedType, "Type spec must be List<OperatorNode<?>>"); - Class<? extends Operator> optype = (Class<? extends Operator>) TypeLiteral.get(((ParameterizedType) arg.getType()).getActualTypeArguments()[0]).getRawType(); + Class<?> rawType = (Class<?>) TypeLiteral.get(((ParameterizedType) arg.getType()).getActualTypeArguments()[0]).getRawType(); + Class<? extends Operator> optype = (Class<? extends Operator>) rawType; return new OperatorNodeListTypeChecker(parent, idx, optype, ImmutableSet.<Operator>of()); } else { return new JavaListTypeChecker(parent, idx, arg.getRawType()); diff --git a/container-search/src/main/javacc/com/yahoo/search/grouping/request/parser/GroupingParser.jj b/container-search/src/main/javacc/com/yahoo/search/grouping/request/parser/GroupingParser.jj index 0678b030bc5..6a55a32eb8a 100644 --- a/container-search/src/main/javacc/com/yahoo/search/grouping/request/parser/GroupingParser.jj +++ b/container-search/src/main/javacc/com/yahoo/search/grouping/request/parser/GroupingParser.jj @@ -404,14 +404,31 @@ AndFunction andFunction(GroupingOperation grp) : AttributeValue attributeValue() : { - StringBuilder ret = new StringBuilder(); + StringBuilder prefix = new StringBuilder(); + StringBuilder suffix = new StringBuilder(); String str; + String key = null; + AttributeFunction keySourceAttr = null; } { - ( str = identifier() { ret.append(str); } - ( ( <DOT> { ret.append(token.image); } ( str = identifier() { ret.append(str); } ) ) | - ( lcurly() str = string() { ret.append("{\"").append(str).append("\"}"); } rcurly() ) )* ) - { return new AttributeValue(ret.toString()); } + ( str = identifier() { prefix.append(str); } + ( LOOKAHEAD(2) <DOT> { prefix.append(token.image); } ( str = identifier() { prefix.append(str); } ) )* + ( LOOKAHEAD(3) + ( lcurly() key = string() rcurly() ) | + ( lcurly() keySourceAttr = attributeFunction() rcurly() ) + )? + ( <DOT> { suffix.append(token.image); } ( str = identifier() { suffix.append(str); } ) )* + ) + { + if (key != null) { + return AttributeMapLookupValue.fromKey(prefix.toString(), key, suffix.toString()); + } else if (keySourceAttr != null) { + return AttributeMapLookupValue.fromKeySourceAttribute(prefix.toString(), keySourceAttr.getAttributeName(), suffix.toString()); + } else { + prefix.append(suffix.toString()); + return new AttributeValue(prefix.toString()); + } + } } AttributeFunction attributeFunction() : diff --git a/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/FastSearcherTester.java b/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/FastSearcherTester.java index 4f99f06986a..4f6d2d88917 100644 --- a/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/FastSearcherTester.java +++ b/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/FastSearcherTester.java @@ -52,7 +52,7 @@ class FastSearcherTester { vipStatus = new VipStatus(clustersStatus); mockFS4ResourcePool = new MockFS4ResourcePool(); mockDispatcher = new MockDispatcher(searchNodes, mockFS4ResourcePool, containerClusterSize, vipStatus); - fastSearcher = new FastSearcher(new MockBackend(Optional.empty(), selfHostname, 0L, true), + fastSearcher = new FastSearcher(new MockBackend(selfHostname, 0L, true), mockFS4ResourcePool, mockDispatcher, new SummaryParameters(null), diff --git a/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/fs4mock/MockBackend.java b/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/fs4mock/MockBackend.java index 01ae9aa8f33..29b28112797 100644 --- a/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/fs4mock/MockBackend.java +++ b/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/fs4mock/MockBackend.java @@ -4,9 +4,6 @@ package com.yahoo.prelude.fastsearch.test.fs4mock; import com.yahoo.fs4.mplex.Backend; import com.yahoo.fs4.mplex.FS4Channel; -import java.util.Optional; -import java.util.function.Supplier; - /** * @author bratseth */ @@ -20,11 +17,11 @@ public class MockBackend extends Backend { private MockFSChannel channel = null; public MockBackend() { - this(Optional.empty(), "", 0L, true); + this("", 0L, true); } - public MockBackend(Optional<Integer> distributionKey, String hostname, long activeDocumentsInBackend, boolean working) { - super(distributionKey); + public MockBackend(String hostname, long activeDocumentsInBackend, boolean working) { + super(); this.hostname = hostname; this.activeDocumentsInBackend = activeDocumentsInBackend; this.working = working; diff --git a/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/fs4mock/MockFS4ResourcePool.java b/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/fs4mock/MockFS4ResourcePool.java index 9b5f4b99f20..7bb161acc07 100644 --- a/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/fs4mock/MockFS4ResourcePool.java +++ b/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/fs4mock/MockFS4ResourcePool.java @@ -26,12 +26,12 @@ public class MockFS4ResourcePool extends FS4ResourcePool { } @Override - public Backend getBackend(String hostname, int port, Optional<Integer> distributionKey) { + public Backend getBackend(String hostname, int port) { countRequest(hostname + ":" + port); if (nonRespondingBackends.contains(hostname)) - return new MockBackend(distributionKey, hostname, 0L, false); + return new MockBackend(hostname, 0L, false); else - return new MockBackend(distributionKey, hostname, activeDocumentsInBackend.getOrDefault(hostname, 0L), true); + return new MockBackend(hostname, activeDocumentsInBackend.getOrDefault(hostname, 0L), true); } /** diff --git a/container-search/src/test/java/com/yahoo/prelude/hitfield/test/JSONStringTestCase.java b/container-search/src/test/java/com/yahoo/prelude/hitfield/test/JSONStringTestCase.java index 09a439c7bc9..18231785a26 100644 --- a/container-search/src/test/java/com/yahoo/prelude/hitfield/test/JSONStringTestCase.java +++ b/container-search/src/test/java/com/yahoo/prelude/hitfield/test/JSONStringTestCase.java @@ -329,8 +329,8 @@ public class JSONStringTestCase { String rendered = js.renderFromInspector(); assertTrue(-1 < rendered.indexOf(f1)); - int offsetF2; - assertTrue(-1 < (offsetF2 = rendered.indexOf(f2))); + int offsetF2 = rendered.indexOf(f2); + assertTrue(-1 < offsetF2); offsetF2 += f2.length(); assertTrue(-1 < rendered.indexOf(f2_1, offsetF2)); assertTrue(-1 < rendered.indexOf(f2_2, offsetF2)); diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/FillTestCase.java b/container-search/src/test/java/com/yahoo/search/dispatch/FillTestCase.java index 0191b1a799b..5e3e0dc301e 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/FillTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/FillTestCase.java @@ -2,7 +2,6 @@ package com.yahoo.search.dispatch; import com.yahoo.compress.CompressionType; -import com.yahoo.log.event.Collection; import com.yahoo.prelude.fastsearch.DocsumDefinition; import com.yahoo.prelude.fastsearch.DocsumDefinitionSet; import com.yahoo.prelude.fastsearch.DocsumField; diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java new file mode 100644 index 00000000000..2ba991310f5 --- /dev/null +++ b/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java @@ -0,0 +1,117 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.dispatch; + +import com.yahoo.search.dispatch.SearchCluster.Group; +import com.yahoo.search.dispatch.SearchCluster.Node; +import junit.framework.AssertionFailedError; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Optional; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertThat; + +/** + * @author ollivir + */ +public class LoadBalancerTest { + @Test + public void requreThatLoadBalancerServesSingleNodeSetups() { + Node n1 = new SearchCluster.Node(0, "test-node1", 0, 0); + SearchCluster cluster = new SearchCluster(88.0, Arrays.asList(n1), null, 1, null); + LoadBalancer lb = new LoadBalancer(cluster); + + Optional<Group> grp = lb.takeGroupForQuery(null); + Group group = grp.orElseGet(() -> { + throw new AssertionFailedError("Expected a SearchCluster.Group"); + }); + assertThat(group.nodes().size(), equalTo(1)); + } + + @Test + public void requreThatLoadBalancerServesMultiGroupSetups() { + Node n1 = new SearchCluster.Node(0, "test-node1", 0, 0); + Node n2 = new SearchCluster.Node(1, "test-node2", 1, 1); + SearchCluster cluster = new SearchCluster(88.0, Arrays.asList(n1, n2), null, 1, null); + LoadBalancer lb = new LoadBalancer(cluster); + + Optional<Group> grp = lb.takeGroupForQuery(null); + Group group = grp.orElseGet(() -> { + throw new AssertionFailedError("Expected a SearchCluster.Group"); + }); + assertThat(group.nodes().size(), equalTo(1)); + } + + @Test + public void requreThatLoadBalancerIgnoresClusteredSingleGroup() { + Node n1 = new SearchCluster.Node(0, "test-node1", 0, 0); + Node n2 = new SearchCluster.Node(1, "test-node2", 1, 0); + SearchCluster cluster = new SearchCluster(88.0, Arrays.asList(n1, n2), null, 2, null); + LoadBalancer lb = new LoadBalancer(cluster); + + Optional<Group> grp = lb.takeGroupForQuery(null); + assertThat(grp.isPresent(), is(false)); + } + + @Test + public void requreThatLoadBalancerIgnoresClusteredGroups() { + Node n1 = new SearchCluster.Node(0, "test-node1", 0, 0); + Node n2 = new SearchCluster.Node(1, "test-node2", 1, 0); + Node n3 = new SearchCluster.Node(0, "test-node3", 0, 1); + Node n4 = new SearchCluster.Node(1, "test-node4", 1, 1); + SearchCluster cluster = new SearchCluster(88.0, Arrays.asList(n1, n2, n3, n4), null, 2, null); + LoadBalancer lb = new LoadBalancer(cluster); + + Optional<Group> grp = lb.takeGroupForQuery(null); + assertThat(grp.isPresent(), is(false)); + } + + @Test + public void requreThatLoadBalancerReturnsDifferentGroups() { + Node n1 = new SearchCluster.Node(0, "test-node1", 0, 0); + Node n2 = new SearchCluster.Node(1, "test-node2", 1, 1); + SearchCluster cluster = new SearchCluster(88.0, Arrays.asList(n1, n2), null, 1, null); + LoadBalancer lb = new LoadBalancer(cluster); + + // get first group + Optional<Group> grp = lb.takeGroupForQuery(null); + Group group = grp.get(); + int id1 = group.id(); + // release allocation + lb.releaseGroup(group); + + // get second group + grp = lb.takeGroupForQuery(null); + group = grp.get(); + assertThat(group.id(), not(equalTo(id1))); + } + + @Test + public void requreThatLoadBalancerReturnsGroupWithShortestQueue() { + Node n1 = new SearchCluster.Node(0, "test-node1", 0, 0); + Node n2 = new SearchCluster.Node(1, "test-node2", 1, 1); + SearchCluster cluster = new SearchCluster(88.0, Arrays.asList(n1, n2), null, 1, null); + LoadBalancer lb = new LoadBalancer(cluster); + + // get first group + Optional<Group> grp = lb.takeGroupForQuery(null); + Group group = grp.get(); + int id1 = group.id(); + + // get second group + grp = lb.takeGroupForQuery(null); + group = grp.get(); + int id2 = group.id(); + assertThat(id2, not(equalTo(id1))); + // release second allocation + lb.releaseGroup(group); + + // get third group + grp = lb.takeGroupForQuery(null); + group = grp.get(); + assertThat(group.id(), equalTo(id2)); + } +} diff --git a/container-search/src/test/java/com/yahoo/search/grouping/GroupingValidatorTestCase.java b/container-search/src/test/java/com/yahoo/search/grouping/GroupingValidatorTestCase.java index 82c05c1d995..9723f96af27 100644 --- a/container-search/src/test/java/com/yahoo/search/grouping/GroupingValidatorTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/grouping/GroupingValidatorTestCase.java @@ -7,51 +7,164 @@ import com.yahoo.search.Query; import com.yahoo.search.config.ClusterConfig; import com.yahoo.search.grouping.request.GroupingOperation; import com.yahoo.search.searchchain.Execution; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import java.util.Arrays; import java.util.Collection; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; - /** * @author Simon Thoresen Hult */ public class GroupingValidatorTestCase { + @Rule + public ExpectedException thrown = ExpectedException.none(); + @Test public void requireThatAvailableAttributesDoNotThrow() { - Query query = new Query(); - GroupingRequest req = GroupingRequest.newInstance(query); - req.setRootOperation(GroupingOperation.fromString("all(group(foo) each(output(max(bar))))")); - validateGrouping("myCluster", Arrays.asList("foo", "bar"), query); + validateGrouping(Arrays.asList("foo", "bar"), + "all(group(foo) each(output(max(bar))))");; } @Test public void requireThatUnavailableAttributesThrow() { - Query query = new Query(); - GroupingRequest req = GroupingRequest.newInstance(query); - req.setRootOperation(GroupingOperation.fromString("all(group(foo) each(output(max(bar))))")); - try { - validateGrouping("myCluster", Arrays.asList("foo"), query); - fail("Validator should throw exception because attribute 'bar' is unavailable."); - } catch (UnavailableAttributeException e) { - assertEquals("myCluster", e.getClusterName()); - assertEquals("bar", e.getAttributeName()); - } + thrown.expect(UnavailableAttributeException.class); + thrown.expectMessage(createMessage("bar")); + validateGrouping(Arrays.asList("foo"), + "all(group(foo) each(output(max(bar))))"); } @Test public void requireThatEnableFlagPreventsThrow() { + Query query = createQuery("all(group(foo) each(output(max(bar))))"); + query.properties().set(GroupingValidator.PARAM_ENABLED, "false"); + validateGrouping(Arrays.asList("foo"), query); + } + + @Test + public void available_primitive_map_attribute_does_not_throw() { + validateGrouping(Arrays.asList("map.key", "map.value"), + "all(group(map{\"foo\"}) each(output(count())))"); + } + + @Test + public void unavailable_primitive_map_key_attribute_throws() { + thrown.expect(UnavailableAttributeException.class); + thrown.expectMessage(createMessage("map.key")); + validateGrouping(Arrays.asList("null"), + "all(group(map{\"foo\"}) each(output(count())))"); + } + + @Test + public void unavailable_primitive_map_value_attribute_throws() { + thrown.expect(UnavailableAttributeException.class); + thrown.expectMessage(createMessage("map.value")); + validateGrouping(Arrays.asList("map.key"), + "all(group(map{\"foo\"}) each(output(count())))"); + } + + @Test + public void available_struct_map_attribute_does_not_throw() { + validateGrouping(Arrays.asList("map.key", "map.value.name"), + "all(group(map{\"foo\"}.name) each(output(count())))"); + } + + @Test + public void unavailable_struct_map_key_attribute_throws() { + thrown.expect(UnavailableAttributeException.class); + thrown.expectMessage(createMessage("map.key")); + validateGrouping(Arrays.asList("null"), + "all(group(map{\"foo\"}.name) each(output(count())))"); + } + + @Test + public void unavailable_struct_map_value_attribute_throws() { + thrown.expect(UnavailableAttributeException.class); + thrown.expectMessage(createMessage("map.value.name")); + validateGrouping(Arrays.asList("map.key"), + "all(group(map{\"foo\"}.name) each(output(count())))"); + } + + @Test + public void available_key_source_attribute_does_not_throw() { + validateGrouping(Arrays.asList("map.key", "map.value", "key_source"), + "all(group(map{attribute(key_source)}) each(output(count())))"); + } + + @Test + public void unavailable_key_source_attribute_throws() { + thrown.expect(UnavailableAttributeException.class); + thrown.expectMessage(createMessage("key_source")); + validateGrouping(Arrays.asList("map.key", "map.value"), + "all(group(map{attribute(key_source)}) each(output(count())))"); + } + + @Test + public void key_source_attribute_with_mismatching_data_type_throws() { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Grouping request references key source attribute 'key_source' with data type 'INT32' " + + "that is different than data type 'STRING' of key attribute 'map.key'"); + + validateGrouping(setupMismatchingKeySourceAttribute(false), + "all(group(map{attribute(key_source)}) each(output(count())))"); + } + + @Test + public void key_source_attribute_with_multi_value_collection_type_throws() { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Grouping request references key source attribute 'key_source' which is not of single value type"); + + validateGrouping(setupMismatchingKeySourceAttribute(true), + "all(group(map{attribute(key_source)}) each(output(count())))"); + } + + private static AttributesConfig setupMismatchingKeySourceAttribute(boolean matchingDataType) { + AttributesConfig.Builder builder = new AttributesConfig.Builder(); + builder.attribute(new AttributesConfig.Attribute.Builder().name("map.key") + .datatype(AttributesConfig.Attribute.Datatype.Enum.STRING)); + builder.attribute(new AttributesConfig.Attribute.Builder().name("map.value")); + builder.attribute(new AttributesConfig.Attribute.Builder().name("key_source") + .datatype(matchingDataType ? AttributesConfig.Attribute.Datatype.Enum.STRING : + AttributesConfig.Attribute.Datatype.Enum.INT32) + .collectiontype(AttributesConfig.Attribute.Collectiontype.Enum.ARRAY)); + return new AttributesConfig(builder); + } + + private static String createMessage(String attributeName) { + return "Grouping request references attribute '" + attributeName + "' which is not available in cluster 'myCluster'."; + } + + private static Query createQuery(String groupingExpression) { Query query = new Query(); GroupingRequest req = GroupingRequest.newInstance(query); - req.setRootOperation(GroupingOperation.fromString("all(group(foo) each(output(max(bar))))")); - query.properties().set(GroupingValidator.PARAM_ENABLED, "false"); - validateGrouping("myCluster", Arrays.asList("foo"), query); + req.setRootOperation(GroupingOperation.fromString(groupingExpression)); + return query; + } + + private static AttributesConfig createAttributesConfig(Collection<String> attributeNames) { + AttributesConfig.Builder builder = new AttributesConfig.Builder(); + for (String attributeName : attributeNames) { + builder.attribute(new AttributesConfig.Attribute.Builder() + .name(attributeName)); + } + return new AttributesConfig(builder); + } + + private static void validateGrouping(Collection<String> attributeNames, String groupingExpression) { + validateGrouping("myCluster", createAttributesConfig(attributeNames), createQuery(groupingExpression)); + } + + private static void validateGrouping(AttributesConfig attributesCfg, String groupingExpression) { + validateGrouping("myCluster", attributesCfg, createQuery(groupingExpression)); } - private static void validateGrouping(String clusterName, Collection<String> attributeNames, Query query) { + private static void validateGrouping(Collection<String> attributeNames, Query query) { + validateGrouping("myCluster", createAttributesConfig(attributeNames), query); + } + + private static void validateGrouping(String clusterName, AttributesConfig attributesConfig, Query query) { QrSearchersConfig.Builder qrsConfig = new QrSearchersConfig.Builder().searchcluster( new QrSearchersConfig.Searchcluster.Builder() .indexingmode(QrSearchersConfig.Searchcluster.Indexingmode.Enum.REALTIME) @@ -59,15 +172,10 @@ public class GroupingValidatorTestCase { ClusterConfig.Builder clusterConfig = new ClusterConfig.Builder(). clusterId(0). clusterName("test"); - AttributesConfig.Builder attributesConfig = new AttributesConfig.Builder(); - for (String attributeName : attributeNames) { - attributesConfig.attribute(new AttributesConfig.Attribute.Builder() - .name(attributeName)); - } new Execution( new GroupingValidator(new QrSearchersConfig(qrsConfig), - new ClusterConfig(clusterConfig), - new AttributesConfig(attributesConfig)), + new ClusterConfig(clusterConfig), + attributesConfig), Execution.Context.createContextStub()).search(query); } } diff --git a/container-search/src/test/java/com/yahoo/search/grouping/request/parser/GroupingParserTestCase.java b/container-search/src/test/java/com/yahoo/search/grouping/request/parser/GroupingParserTestCase.java index 2c43873036e..afbad73f982 100644 --- a/container-search/src/test/java/com/yahoo/search/grouping/request/parser/GroupingParserTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/grouping/request/parser/GroupingParserTestCase.java @@ -2,6 +2,7 @@ package com.yahoo.search.grouping.request.parser; import com.yahoo.search.grouping.request.AllOperation; +import com.yahoo.search.grouping.request.AttributeMapLookupValue; import com.yahoo.search.grouping.request.EachOperation; import com.yahoo.search.grouping.request.GroupingOperation; import com.yahoo.search.query.parser.Parsable; @@ -24,12 +25,6 @@ import static org.junit.Assert.fail; */ public class GroupingParserTestCase { - // -------------------------------------------------------------------------------- - // - // Tests. - // - // -------------------------------------------------------------------------------- - @Test public void requireThatMathAllowsWhitespace() { for (String op : Arrays.asList("+", " +", " + ", "+ ", @@ -448,6 +443,46 @@ public class GroupingParserTestCase { assertParse("all(group(my.little{key }))", "all(group(my.little{\"key\"}))"); assertParse("all(group(my.little{\"key\"}))", "all(group(my.little{\"key\"}))"); assertParse("all(group(my.little{\"key{}%\"}))", "all(group(my.little{\"key{}%\"}))"); + assertParse("all(group(my.little{key}.name))", "all(group(my.little{\"key\"}.name))"); + assertParse("all(group(my.little{key }.name))", "all(group(my.little{\"key\"}.name))"); + assertParse("all(group(my.little{\"key\"}.name))", "all(group(my.little{\"key\"}.name))"); + assertParse("all(group(my.little{\"key{}%\"}.name))", "all(group(my.little{\"key{}%\"}.name))"); + + assertAttributeMapLookup("all(group(my_map{\"my_key\"}))", + "my_map.key", "my_map.value", "my_key", ""); + assertAttributeMapLookup("all(group(my_map{\"my_key\"}.name))", + "my_map.key", "my_map.value.name", "my_key", ""); + assertAttributeMapLookup("all(group(my.map{\"my_key\"}))", + "my.map.key", "my.map.value", "my_key", ""); + } + + @Test + public void testMapSyntaxWithKeySourceAttribute() { + assertAttributeMapLookup("all(group(my_map{attribute(my_attr)}))", + "my_map.key", "my_map.value", "", "my_attr"); + assertAttributeMapLookup("all(group(my_map{attribute(my_attr)}.name))", + "my_map.key", "my_map.value.name", "", "my_attr"); + assertAttributeMapLookup("all(group(my.map{attribute(my_attr.name)}))", + "my.map.key", "my.map.value", "", "my_attr.name"); + + assertIllegalArgument("all(group(my_map{attribute(\"my_attr\")}))", + "Encountered \" <STRING> \"\\\"my_attr\\\" \"\" at line 1, column 28"); + + } + + private static void assertAttributeMapLookup(String request, + String expKeyAttribute, + String expValueAttribute, + String expKey, + String expKeySourceAttribute) { + assertParse(request, request); + List<GroupingOperation> operations = GroupingOperation.fromStringAsList(request); + assertEquals(1, operations.size()); + AttributeMapLookupValue mapLookup = (AttributeMapLookupValue)operations.get(0).getGroupBy(); + assertEquals(expKeyAttribute, mapLookup.getKeyAttribute()); + assertEquals(expValueAttribute, mapLookup.getValueAttribute()); + assertEquals(expKey, mapLookup.getKey()); + assertEquals(expKeySourceAttribute, mapLookup.getKeySourceAttribute()); } @Test diff --git a/container-search/src/test/java/com/yahoo/search/grouping/vespa/RequestBuilderTestCase.java b/container-search/src/test/java/com/yahoo/search/grouping/vespa/RequestBuilderTestCase.java index b8571aacca4..c64c4d624f2 100644 --- a/container-search/src/test/java/com/yahoo/search/grouping/vespa/RequestBuilderTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/grouping/vespa/RequestBuilderTestCase.java @@ -681,6 +681,24 @@ public class RequestBuilderTestCase { assertOutput(test); } + @Test + public void requireThatAttributeMapLookupNodeIsCreatedFromKey() { + RequestTest test = new RequestTest(); + test.expectedOutput = AttributeMapLookupNode.fromKey("map{\"my_key\"}", "map.key", "map.value", "my_key").toString(); + test.request = "all(group(map{\"my_key\"}) each(output(count())))"; + test.outputWriter = (groupingList, transform) -> groupingList.get(0).getLevels().get(0).getExpression().toString(); + assertOutput(test); + } + + @Test + public void requireThatAttributeMapLookupNodeIsCreatedFromKeySourceAttribute() { + RequestTest test = new RequestTest(); + test.expectedOutput = AttributeMapLookupNode.fromKeySourceAttribute("map{attribute(key_source)}", "map.key", "map.value", "key_source").toString(); + test.request = "all(group(map{attribute(key_source)}) each(output(count())))"; + test.outputWriter = (groupingList, transform) -> groupingList.get(0).getLevels().get(0).getExpression().toString(); + assertOutput(test); + } + private static CompositeContinuation newComposite(EncodableContinuation... conts) { CompositeContinuation ret = new CompositeContinuation(); for (EncodableContinuation cont : conts) { diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/BuildService.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/BuildService.java index 56c2ee8da6b..d0edbdcb8a6 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/BuildService.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/BuildService.java @@ -37,8 +37,6 @@ public interface BuildService { } - // TODO jvenstad: Argh, refactor this, considering the new JobId, etc.. - // TODO jvenstad: Probably: make jobName JobType instead. class BuildJob { private final ApplicationId applicationId; diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ConfigServer.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ConfigServer.java index 54e057e4187..eb10c78f891 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ConfigServer.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ConfigServer.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.api.integration.configserver; +import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.vespa.hosted.controller.api.application.v4.model.DeployOptions; import com.yahoo.vespa.hosted.controller.api.application.v4.model.EndpointStatus; import com.yahoo.vespa.hosted.controller.api.identifiers.DeploymentId; @@ -41,6 +42,7 @@ public interface ConfigServer { Map<?,?> getServiceApiResponse(String tenantName, String applicationName, String instanceName, String environment, String region, String serviceName, String restPath); + HttpResponse getLogs(DeploymentId deployment); /** * Set new status on en endpoint in one zone. * diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ServiceConvergence.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ServiceConvergence.java index 8a90224083b..6cfdc9fadc8 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ServiceConvergence.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ServiceConvergence.java @@ -1,35 +1,63 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.api.integration.configserver; +import com.google.common.collect.ImmutableList; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.HostName; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; +import java.util.List; +import java.util.OptionalLong; + /** * Service convergence status for an application. * * @author mpolden + * @author jonmv */ public class ServiceConvergence { private final ApplicationId application; private final ZoneId zone; private final boolean converged; + private final long wantedGeneration; + private final List<Status> services; - public ServiceConvergence(ApplicationId application, ZoneId zone, boolean converged) { + public ServiceConvergence(ApplicationId application, ZoneId zone, boolean converged, + long wantedGeneration, List<Status> services) { this.application = application; this.zone = zone; this.converged = converged; + this.wantedGeneration = wantedGeneration; + this.services = ImmutableList.copyOf(services); } - public ApplicationId application() { - return application; - } + public ApplicationId application() { return application; } + public ZoneId zone() { return zone; } + public boolean converged() { return converged; } + public long wantedGeneration() { return wantedGeneration; } + public List<Status> services() { return services; } - public ZoneId zone() { - return zone; - } - public boolean converged() { - return converged; + /** Immutable class detailing the config status of a particular service for an application. */ + public static class Status { + private final HostName host; + private final long port; + private final String type; + private final long currentGeneration; + + public Status(HostName host, long port, String type, long currentGeneration) { + this.host = host; + this.port = port; + this.type = type; + this.currentGeneration = currentGeneration; + } + + public HostName host() { return host; } + public long port() { return port; } + public String type() { return type; } + public long currentGeneration() { return currentGeneration; } + } + } diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/TesterCloud.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/TesterCloud.java index f2dbcda54b6..cb6332420ed 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/TesterCloud.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/TesterCloud.java @@ -21,6 +21,9 @@ public interface TesterCloud { /** Returns the current status of the tester. */ Status getStatus(URI testerUrl); + /** Returns whether the tester is ready to serve. */ + boolean ready(URI testerUrl); + enum Status { diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/MockOrganization.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/MockOrganization.java index 96ee9ecd052..8efbde52d4a 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/MockOrganization.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/MockOrganization.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.hosted.controller.api.integration.organization; import com.google.inject.Inject; +import com.yahoo.component.AbstractComponent; import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; import java.net.URI; @@ -18,7 +19,7 @@ import java.util.concurrent.atomic.AtomicLong; /** * @author jvenstad */ -public class MockOrganization implements Organization { +public class MockOrganization extends AbstractComponent implements Organization { private final Clock clock; private final AtomicLong counter = new AtomicLong(); @@ -89,45 +90,58 @@ public class MockOrganization implements Organization { @Override public URI issueCreationUri(PropertyId propertyId) { - return URI.create("www.issues.tld/" + propertyId.id()); + return properties.getOrDefault(propertyId, new PropertyInfo()).issueUrl; } @Override public URI contactsUri(PropertyId propertyId) { - return URI.create("www.contacts.tld/" + propertyId.id()); + return properties.getOrDefault(propertyId, new PropertyInfo()).contactsUrl; } @Override public URI propertyUri(PropertyId propertyId) { - return URI.create("www.properties.tld/" + propertyId.id()); + return properties.getOrDefault(propertyId, new PropertyInfo()).propertyUrl; } public Map<IssueId, MockIssue> issues() { return Collections.unmodifiableMap(issues); } - public void close(IssueId issueId) { + public MockOrganization close(IssueId issueId) { issues.get(issueId).open = false; touch(issueId); + return this; } - public void setDefaultAssigneeFor(PropertyId propertyId, User defaultAssignee) { - properties.get(propertyId).defaultAssignee = defaultAssignee; + public MockOrganization setContactsFor(PropertyId propertyId, List<List<User>> contacts) { + properties.get(propertyId).contacts = contacts; + return this; } - public void setContactsFor(PropertyId propertyId, List<List<User>> contacts) { - properties.get(propertyId).contacts = contacts; + public MockOrganization setPropertyUrl(PropertyId propertyId, URI url) { + properties.get(propertyId).propertyUrl = url; + return this; + } + + public MockOrganization setContactsUrl(PropertyId propertyId, URI url) { + properties.get(propertyId).contactsUrl = url; + return this; } - public void addProperty(PropertyId propertyId) { + public MockOrganization setIssueUrl(PropertyId propertyId, URI url) { + properties.get(propertyId).issueUrl = url; + return this; + } + + public MockOrganization addProperty(PropertyId propertyId) { properties.put(propertyId, new PropertyInfo()); + return this; } private void touch(IssueId issueId) { issues.get(issueId).updated = clock.instant(); } - public class MockIssue { private Issue issue; @@ -148,11 +162,13 @@ public class MockOrganization implements Organization { } - private class PropertyInfo { private User defaultAssignee; private List<List<User>> contacts = Collections.emptyList(); + private URI issueUrl; + private URI contactsUrl; + private URI propertyUrl; } diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/stubs/MockTesterCloud.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/stubs/MockTesterCloud.java index 176fa8ae683..2f2eff99a4f 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/stubs/MockTesterCloud.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/stubs/MockTesterCloud.java @@ -37,6 +37,11 @@ public class MockTesterCloud implements TesterCloud { return status; } + @Override + public boolean ready(URI resterUrl) { + return true; + } + public void add(LogEntry entry) { log.add(entry); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java index e984edca7db..af5b9198343 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java @@ -27,6 +27,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.configserver.NoInstance import com.yahoo.vespa.hosted.controller.api.integration.configserver.PrepareResponse; import com.yahoo.vespa.hosted.controller.api.integration.deployment.ApplicationStore; import com.yahoo.vespa.hosted.controller.api.integration.deployment.ArtifactRepository; +import com.yahoo.vespa.hosted.controller.api.integration.deployment.JobType; import com.yahoo.vespa.hosted.controller.api.integration.dns.NameService; import com.yahoo.vespa.hosted.controller.api.integration.dns.Record; import com.yahoo.vespa.hosted.controller.api.integration.dns.RecordData; @@ -38,11 +39,11 @@ import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; import com.yahoo.vespa.hosted.controller.application.ApplicationVersion; import com.yahoo.vespa.hosted.controller.application.Deployment; -import com.yahoo.vespa.hosted.controller.api.integration.deployment.JobType; import com.yahoo.vespa.hosted.controller.application.JobList; import com.yahoo.vespa.hosted.controller.application.JobStatus; import com.yahoo.vespa.hosted.controller.application.JobStatus.JobRun; import com.yahoo.vespa.hosted.controller.application.SystemApplication; +import com.yahoo.vespa.hosted.controller.concurrent.Once; import com.yahoo.vespa.hosted.controller.deployment.DeploymentTrigger; import com.yahoo.vespa.hosted.controller.persistence.CuratorDb; import com.yahoo.vespa.hosted.controller.rotation.Rotation; @@ -57,6 +58,8 @@ import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; import java.time.Clock; +import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -115,9 +118,17 @@ public class ApplicationController { this.rotationRepository = new RotationRepository(rotationsConfig, this, curator); this.deploymentTrigger = new DeploymentTrigger(controller, buildService, clock); - for (Application application : curator.readApplications()) { - lockIfPresent(application.id(), this::store); - } + // Update serialization format of all applications + Once.after(Duration.ofMinutes(1), () -> { + Instant start = clock.instant(); + int count = 0; + for (Application application : curator.readApplications()) { + lockIfPresent(application.id(), this::store); + count++; + } + log.log(Level.INFO, String.format("Wrote %d applications in %s", count, + Duration.between(start, clock.instant()))); + }); } /** Returns the application with the given id, or null if it is not present */ @@ -240,7 +251,9 @@ public class ApplicationController { */ public Application createApplication(ApplicationId id, Optional<NToken> token) { if ( ! (id.instance().isDefault())) // TODO: Support instances properly - throw new UnsupportedOperationException("Only the instance name 'default' is supported at the moment"); + throw new IllegalArgumentException("Only the instance name 'default' is supported at the moment"); + if (id.instance().isTester()) + throw new IllegalArgumentException("'" + id + "' is a tester application!"); try (Lock lock = lock(id)) { // Validate only application names which do not already exist. if (asList(id.tenant()).stream().noneMatch(application -> application.id().application().equals(id.application()))) @@ -270,9 +283,13 @@ public class ApplicationController { /** Deploys an application. If the application does not exist it is created. */ // TODO: Get rid of the options arg + // TODO jvenstad: Split this, and choose between deployDirectly and deploy in handler, excluding internally built from the latter. public ActivateResult deploy(ApplicationId applicationId, ZoneId zone, Optional<ApplicationPackage> applicationPackageFromDeployer, DeployOptions options) { + if (applicationId.instance().isTester()) + throw new IllegalArgumentException("'" + applicationId + "' is a tester application!"); + try (Lock lock = lock(applicationId)) { LockedApplication application = get(applicationId) .map(app -> new LockedApplication(app, lock)) @@ -375,7 +392,7 @@ public class ApplicationController { /** Assembles and deploys a tester application to the given zone. */ public ActivateResult deployTester(ApplicationId tester, ApplicationPackage applicationPackage, ZoneId zone, DeployOptions options) { - if ( ! tester.instance().value().endsWith("-t")) + if ( ! tester.instance().isTester()) throw new IllegalArgumentException("'" + tester + "' is not a tester application!"); return deploy(tester, applicationPackage, zone, options, Collections.emptySet(), Collections.emptySet()); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java index ee0a6875796..1576ab597be 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java @@ -12,8 +12,8 @@ import com.yahoo.vespa.curator.Lock; import com.yahoo.vespa.hosted.controller.api.identifiers.Property; import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; import com.yahoo.vespa.hosted.controller.api.integration.BuildService; -import com.yahoo.vespa.hosted.controller.api.integration.RunDataStore; import com.yahoo.vespa.hosted.controller.api.integration.MetricsService; +import com.yahoo.vespa.hosted.controller.api.integration.RunDataStore; import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzClientFactory; import com.yahoo.vespa.hosted.controller.api.integration.chef.Chef; import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServer; @@ -79,7 +79,6 @@ public class Controller extends AbstractComponent { private final ConfigServer configServer; private final MetricsService metricsService; private final Chef chef; - private final Organization organization; private final AthenzClientFactory athenzClientFactory; /** @@ -97,14 +96,14 @@ public class Controller extends AbstractComponent { ArtifactRepository artifactRepository, ApplicationStore applicationStore, TesterCloud testerCloud, BuildService buildService, RunDataStore runDataStore) { this(curator, rotationsConfig, - gitHub, entityService, organization, globalRoutingService, zoneRegistry, + gitHub, entityService, globalRoutingService, zoneRegistry, configServer, metricsService, nameService, routingGenerator, chef, Clock.systemUTC(), athenzClientFactory, artifactRepository, applicationStore, testerCloud, buildService, runDataStore, com.yahoo.net.HostName::getLocalhost); } public Controller(CuratorDb curator, RotationsConfig rotationsConfig, - GitHub gitHub, EntityService entityService, Organization organization, + GitHub gitHub, EntityService entityService, GlobalRoutingService globalRoutingService, ZoneRegistry zoneRegistry, ConfigServer configServer, MetricsService metricsService, NameService nameService, @@ -117,7 +116,6 @@ public class Controller extends AbstractComponent { this.curator = Objects.requireNonNull(curator, "Curator cannot be null"); this.gitHub = Objects.requireNonNull(gitHub, "GitHub cannot be null"); this.entityService = Objects.requireNonNull(entityService, "EntityService cannot be null"); - this.organization = Objects.requireNonNull(organization, "Organization cannot be null"); this.globalRoutingService = Objects.requireNonNull(globalRoutingService, "GlobalRoutingService cannot be null"); this.zoneRegistry = Objects.requireNonNull(zoneRegistry, "ZoneRegistry cannot be null"); this.configServer = Objects.requireNonNull(configServer, "ConfigServer cannot be null"); @@ -289,10 +287,6 @@ public class Controller extends AbstractComponent { return chef; } - public Organization organization() { - return organization; - } - public CuratorDb curator() { return curator; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedTenant.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedTenant.java new file mode 100644 index 00000000000..cb3f50d08c7 --- /dev/null +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedTenant.java @@ -0,0 +1,76 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller; + +import com.yahoo.config.provision.TenantName; +import com.yahoo.vespa.athenz.api.AthenzDomain; +import com.yahoo.vespa.curator.Lock; +import com.yahoo.vespa.hosted.controller.api.identifiers.Property; +import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; +import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; +import com.yahoo.vespa.hosted.controller.tenant.Contact; + +import java.util.Objects; +import java.util.Optional; +import java.util.function.Consumer; + +/** + * A tenant that has been locked for modification. Provides methods for modifying a tenant's fields. + * + * @author mpolden + */ +public class LockedTenant { + + private final Lock lock; + private final TenantName name; + private final AthenzDomain domain; + private final Property property; + private final Optional<PropertyId> propertyId; + private final Optional<Contact> contact; + + /** + * Should never be constructed directly. + * + * Use {@link TenantController#lockIfPresent(TenantName, Consumer)} or + * {@link TenantController#lockOrThrow(TenantName, Consumer)} + */ + LockedTenant(AthenzTenant tenant, Lock lock) { + this(lock, tenant.name(), tenant.domain(), tenant.property(), tenant.propertyId(), tenant.contact()); + } + + private LockedTenant(Lock lock, TenantName name, AthenzDomain domain, Property property, + Optional<PropertyId> propertyId, Optional<Contact> contact) { + this.lock = Objects.requireNonNull(lock, "lock must be non-null"); + this.name = Objects.requireNonNull(name, "name must be non-null"); + this.domain = Objects.requireNonNull(domain, "domain must be non-null"); + this.property = Objects.requireNonNull(property, "property must be non-null"); + this.propertyId = Objects.requireNonNull(propertyId, "propertyId must be non-null"); + this.contact = Objects.requireNonNull(contact, "contact must be non-null"); + } + + /** Returns a read-only copy of this */ + public AthenzTenant get() { + return new AthenzTenant(name, domain, property, propertyId, contact); + } + + public LockedTenant with(AthenzDomain domain) { + return new LockedTenant(lock, name, domain, property, propertyId, contact); + } + + public LockedTenant with(Property property) { + return new LockedTenant(lock, name, domain, property, propertyId, contact); + } + + public LockedTenant with(PropertyId propertyId) { + return new LockedTenant(lock, name, domain, property, Optional.of(propertyId), contact); + } + + public LockedTenant with(Contact contact) { + return new LockedTenant(lock, name, domain, property, propertyId, Optional.of(contact)); + } + + @Override + public String toString() { + return "tenant '" + name + "'"; + } + +} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java index 228ca01e764..1ae3e6a6577 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java @@ -10,17 +10,22 @@ import com.yahoo.vespa.curator.Lock; import com.yahoo.vespa.hosted.controller.api.identifiers.UserId; import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzClientFactory; import com.yahoo.vespa.hosted.controller.api.integration.athenz.ZmsClient; +import com.yahoo.vespa.hosted.controller.concurrent.Once; import com.yahoo.vespa.hosted.controller.persistence.CuratorDb; import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; import com.yahoo.vespa.hosted.controller.tenant.Tenant; import com.yahoo.vespa.hosted.controller.tenant.UserTenant; import java.time.Duration; +import java.time.Instant; import java.util.Comparator; import java.util.HashSet; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.function.Consumer; +import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -34,30 +39,39 @@ public class TenantController { private static final Logger log = Logger.getLogger(TenantController.class.getName()); - /** The controller owning this */ private final Controller controller; - - /** For persistence */ private final CuratorDb curator; - private final AthenzClientFactory athenzClientFactory; public TenantController(Controller controller, CuratorDb curator, AthenzClientFactory athenzClientFactory) { - this.controller = controller; - this.curator = curator; - this.athenzClientFactory = athenzClientFactory; - // Write all tenants to ensure persisted data uses latest serialization format - for (Tenant tenant : curator.readTenants()) { - try (Lock lock = lock(tenant.name())) { - if (tenant instanceof AthenzTenant) { - curator.writeTenant((AthenzTenant) tenant); - } else if (tenant instanceof UserTenant) { - curator.writeTenant((UserTenant) tenant); - } else { - throw new IllegalArgumentException("Unknown tenant type: " + tenant.getClass().getSimpleName()); + this.controller = Objects.requireNonNull(controller, "controller must be non-null"); + this.curator = Objects.requireNonNull(curator, "curator must be non-null"); + this.athenzClientFactory = Objects.requireNonNull(athenzClientFactory, "athenzClientFactory must be non-null"); + + // Update serialization format of all tenants + Once.after(Duration.ofMinutes(1), () -> { + Instant start = controller.clock().instant(); + int count = 0; + for (TenantName name : curator.readTenantNames()) { + try (Lock lock = lock(name)) { + // Get while holding lock so that we know we're operating on a current version + Optional<Tenant> optionalTenant = tenant(name); + if (!optionalTenant.isPresent()) continue; // Deleted while updating, skip + + Tenant tenant = optionalTenant.get(); + if (tenant instanceof AthenzTenant) { + curator.writeTenant((AthenzTenant) tenant); + } else if (tenant instanceof UserTenant) { + curator.writeTenant((UserTenant) tenant); + } else { + throw new IllegalArgumentException("Unknown tenant type: " + tenant.getClass().getSimpleName()); + } } + count++; } - } + log.log(Level.INFO, String.format("Wrote %d tenants in %s", count, + Duration.between(start, controller.clock().instant()))); + }); } /** Returns a list of all known tenants sorted by name */ @@ -79,12 +93,33 @@ public class TenantController { } } + /** + * Lock a tenant for modification and apply action. Only valid for Athenz tenants as it's the only type that + * accepts modification. + */ + public void lockIfPresent(TenantName name, Consumer<LockedTenant> action) { + try (Lock lock = lock(name)) { + athenzTenant(name).map(tenant -> new LockedTenant(tenant, lock)).ifPresent(action); + } + } + + /** Lock a tenant for modification and apply action. Throws if the tenant does not exist */ + public void lockOrThrow(TenantName name, Consumer<LockedTenant> action) { + try (Lock lock = lock(name)) { + action.accept(new LockedTenant(requireAthenzTenant(name), lock)); + } + } + + /** Replace and store any previous version of given tenant */ + public void store(LockedTenant tenant) { + curator.writeTenant(tenant.get()); + } + /** Create an user tenant with given username */ public void create(UserTenant tenant) { try (Lock lock = lock(tenant.name())) { requireNonExistent(tenant.name()); curator.writeTenant(tenant); - log.info("Created " + tenant); } } @@ -103,7 +138,6 @@ public class TenantController { } athenzClientFactory.createZmsClientWithAuthorizedServiceToken(token).createTenant(domain); curator.writeTenant(tenant); - log.info("Created " + tenant); } } @@ -129,14 +163,29 @@ public class TenantController { return curator.readAthenzTenant(name); } - /** Update Athenz tenant */ - public void updateTenant(AthenzTenant updatedTenant, NToken token) { - try (Lock lock = lock(updatedTenant.name())) { - requireExists(updatedTenant.name()); - updateAthenzDomain(updatedTenant, token); - curator.writeTenant(updatedTenant); - log.info("Updated " + updatedTenant); - } + /** Returns Athenz tenant with name or throws if no such tenant exists */ + public AthenzTenant requireAthenzTenant(TenantName name) { + return athenzTenant(name).orElseThrow(() -> new IllegalArgumentException("Tenant '" + name + "' not found")); + } + + /** Update Athenz domain for tenant. Returns the updated tenant which must be explicitly stored */ + public LockedTenant withDomain(LockedTenant tenant, AthenzDomain newDomain, NToken token) { + AthenzDomain existingDomain = tenant.get().domain(); + if (existingDomain.equals(newDomain)) return tenant; + Optional<Tenant> existingTenantWithNewDomain = tenantIn(newDomain); + if (existingTenantWithNewDomain.isPresent()) + throw new IllegalArgumentException("Could not set domain of " + tenant + " to '" + newDomain + + "':" + existingTenantWithNewDomain.get() + " already has this domain"); + + ZmsClient zmsClient = athenzClientFactory.createZmsClientWithAuthorizedServiceToken(token); + zmsClient.createTenant(newDomain); + List<Application> applications = controller.applications().asList(tenant.get().name()); + applications.forEach(a -> zmsClient.addApplication(newDomain, new com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId(a.id().application().value()))); + applications.forEach(a -> zmsClient.deleteApplication(existingDomain, new com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId(a.id().application().value()))); + zmsClient.deleteTenant(existingDomain); + log.info("Set Athenz domain for '" + tenant + "' from '" + existingDomain + "' to '" + newDomain + "'"); + + return tenant.with(newDomain); } /** Delete an user tenant */ @@ -160,28 +209,6 @@ public class TenantController { + "': This tenant has active applications"); } curator.removeTenant(name); - log.info("Deleted " + name); - } - - private void updateAthenzDomain(AthenzTenant updatedTenant, NToken token) { - Optional<AthenzTenant> existingTenant = athenzTenant(updatedTenant.name()); - if ( ! existingTenant.isPresent()) return; - - AthenzDomain existingDomain = existingTenant.get().domain(); - AthenzDomain newDomain = updatedTenant.domain(); - if (existingDomain.equals(newDomain)) return; - Optional<Tenant> existingTenantWithNewDomain = tenantIn(newDomain); - if (existingTenantWithNewDomain.isPresent()) - throw new IllegalArgumentException("Could not set domain of " + updatedTenant + " to '" + newDomain + - "':" + existingTenantWithNewDomain.get() + " already has this domain"); - - ZmsClient zmsClient = athenzClientFactory.createZmsClientWithAuthorizedServiceToken(token); - zmsClient.createTenant(newDomain); - List<Application> applications = controller.applications().asList(existingTenant.get().name()); - applications.forEach(a -> zmsClient.addApplication(newDomain, new com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId(a.id().application().value()))); - applications.forEach(a -> zmsClient.deleteApplication(existingDomain, new com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId(a.id().application().value()))); - zmsClient.deleteTenant(existingDomain); - log.info("Updated Athens domain for " + updatedTenant + " from " + existingDomain + " to " + newDomain); } private void requireNonExistent(TenantName name) { @@ -193,12 +220,6 @@ public class TenantController { } } - private void requireExists(TenantName name) { - if (!tenant(name).isPresent()) { - throw new IllegalArgumentException("Tenant '" + name + "' does not exist"); - } - } - /** * Returns a lock which provides exclusive rights to changing this tenant. * Any operation which stores a tenant need to first acquire this lock, then read, modify diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationVersion.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationVersion.java index b3dd46a3e65..703a198be1e 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationVersion.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationVersion.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.hosted.controller.application; import java.util.Objects; import java.util.Optional; +import java.util.OptionalLong; /** * An application package version, identified by a source revision and a build number. @@ -16,21 +17,21 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> { * Used in cases where application version cannot be determined, such as manual deployments (e.g. in dev * environment) */ - public static final ApplicationVersion unknown = new ApplicationVersion(Optional.empty(), Optional.empty()); + public static final ApplicationVersion unknown = new ApplicationVersion(Optional.empty(), OptionalLong.empty()); // This never changes and is only used to create a valid semantic version number, as required by application bundles private static final String majorVersion = "1.0"; private final Optional<SourceRevision> source; - private final Optional<Long> buildNumber; + private final OptionalLong buildNumber; - private ApplicationVersion(Optional<SourceRevision> source, Optional<Long> buildNumber) { + private ApplicationVersion(Optional<SourceRevision> source, OptionalLong buildNumber) { Objects.requireNonNull(source, "source cannot be null"); Objects.requireNonNull(buildNumber, "buildNumber cannot be null"); if (source.isPresent() != buildNumber.isPresent()) { throw new IllegalArgumentException("both buildNumber and source must be set together"); } - if (buildNumber.isPresent() && buildNumber.get() <= 0) { + if (buildNumber.isPresent() && buildNumber.getAsLong() <= 0) { throw new IllegalArgumentException("buildNumber must be > 0"); } this.source = source; @@ -39,7 +40,7 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> { /** Create an application package version from a completed build */ public static ApplicationVersion from(SourceRevision source, long buildNumber) { - return new ApplicationVersion(Optional.of(source), Optional.of(buildNumber)); + return new ApplicationVersion(Optional.of(source), OptionalLong.of(buildNumber)); } /** Returns an unique identifier for this version or "unknown" if version is not known */ @@ -47,7 +48,7 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> { if (isUnknown()) { return "unknown"; } - return String.format("%s.%d-%s", majorVersion, buildNumber.get(), abbreviateCommit(source.get().commit())); + return String.format("%s.%d-%s", majorVersion, buildNumber.getAsLong(), abbreviateCommit(source.get().commit())); } /** @@ -57,7 +58,7 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> { public Optional<SourceRevision> source() { return source; } /** Returns the build number that built this version */ - public Optional<Long> buildNumber() { return buildNumber; } + public OptionalLong buildNumber() { return buildNumber; } /** Returns whether this is unknown */ public boolean isUnknown() { @@ -93,6 +94,6 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> { if ( ! buildNumber().isPresent() || ! o.buildNumber().isPresent()) return Boolean.compare(buildNumber().isPresent(), o.buildNumber.isPresent()); // Application package hash sorts first - return buildNumber().get().compareTo(o.buildNumber().get()); + return Long.compare(buildNumber().getAsLong(), o.buildNumber().getAsLong()); } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java index 0a062427a8a..a2433d223dc 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java @@ -68,12 +68,12 @@ public class Deployment { public DeploymentActivity activity() { return activity; } /** Returns information about the clusters allocated to this */ - public Map<Id, ClusterInfo> clusterInfo() { + public Map<Id, ClusterInfo> clusterInfo() { return clusterInfo; } /** Returns utilization of the clusters allocated to this */ - public Map<Id, ClusterUtilization> clusterUtils() { + public Map<Id, ClusterUtilization> clusterUtils() { return clusterUtils; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/SystemApplication.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/SystemApplication.java index cc4f236f3b2..5a57394ff6b 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/SystemApplication.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/SystemApplication.java @@ -5,7 +5,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.NodeType; -import com.yahoo.config.provision.RegionName; import com.yahoo.vespa.hosted.controller.Controller; import com.yahoo.vespa.hosted.controller.api.identifiers.DeploymentId; import com.yahoo.vespa.hosted.controller.api.integration.configserver.ServiceConvergence; @@ -68,13 +67,17 @@ public enum SystemApplication { if (!hasApplicationPackage()) { return true; } - // TODO: Remove this hack once Docker hosts are removed from zone-application. - if (isAws(zone.region())) { - return true; // Skip checking config convergence on AWS as Docker hosts do not have cloud config - } + // TODO: Docker hosts running host admin cannot be checked. Since a zone can have + // Docker hosts running either host admin or node-admin, it's not possible to check + // config convergence, so we need to always return true here. + // We want to remove the line below and check config convergence for proxy nodes + // when all Docker hosts are running host admin + return true; + /* return controller.configServer().serviceConvergence(new DeploymentId(id(), zone)) .map(ServiceConvergence::converged) .orElse(false); + */ } /** Returns the node types of this that should receive OS upgrades */ @@ -92,8 +95,4 @@ public enum SystemApplication { return String.format("system application %s of type %s", id, nodeTypes); } - private static boolean isAws(RegionName region) { - return region.value().startsWith("cd-aws-") || region.value().startsWith("aws-"); - } - } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzTrustStoreConfigurator.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzTrustStoreConfigurator.java deleted file mode 100644 index 909104d1731..00000000000 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzTrustStoreConfigurator.java +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.controller.athenz.filter; - -import com.google.inject.Inject; -import com.yahoo.jdisc.http.ssl.SslTrustStoreConfigurator; -import com.yahoo.jdisc.http.ssl.SslTrustStoreContext; -import com.yahoo.vespa.athenz.tls.KeyStoreBuilder; -import com.yahoo.vespa.athenz.tls.KeyStoreType; -import com.yahoo.vespa.hosted.controller.athenz.config.AthenzConfig; - -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.security.KeyStore; -import java.security.KeyStoreException; -import java.security.NoSuchAlgorithmException; -import java.security.cert.CertificateException; - -/** - * Load trust store with Athenz CA certificates - * - * @author bjorncs - */ -public class AthenzTrustStoreConfigurator implements SslTrustStoreConfigurator { - - private final KeyStore trustStore; - - @Inject - public AthenzTrustStoreConfigurator(AthenzConfig config) { - this.trustStore = createTrustStore(new File(config.athenzCaTrustStore())); - } - - private static KeyStore createTrustStore(File trustStoreFile) { - return KeyStoreBuilder.withType(KeyStoreType.JKS) - .fromFile(trustStoreFile, "changeit".toCharArray()) - .build(); - } - - @Override - public void configure(SslTrustStoreContext context) { - context.updateTrustStore(trustStore); - } -} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/concurrent/Once.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/concurrent/Once.java new file mode 100644 index 00000000000..81ddd8d2d70 --- /dev/null +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/concurrent/Once.java @@ -0,0 +1,46 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.concurrent; + +import java.time.Duration; +import java.util.Objects; +import java.util.Timer; +import java.util.TimerTask; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Execute a runnable exactly once in a background thread. + * + * @author mpolden + */ +public class Once extends TimerTask { + + private static final Logger log = Logger.getLogger(Once.class.getName()); + + private final Runnable runnable; + private final Timer timer = new Timer(true); + + // private to avoid exposing run method + private Once(Runnable runnable, Duration delay) { + this.runnable = Objects.requireNonNull(runnable, "runnable must be non-null"); + Objects.requireNonNull(delay, "delay must be non-null"); + timer.schedule(this, delay.toMillis()); + } + + /** Execute runnable after given delay */ + public static void after(Duration delay, Runnable runnable) { + new Once(runnable, delay); + } + + @Override + public void run() { + try { + runnable.run(); + } catch (Throwable t) { + log.log(Level.WARNING, "Task '" + runnable + "' failed", t); + } finally { + timer.cancel(); + } + } + +} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java index becef782519..2284b82bcfb 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java @@ -412,10 +412,13 @@ public class DeploymentTrigger { */ public boolean isComplete(Change change, Application application, JobType jobType) { Optional<Deployment> existingDeployment = deploymentFor(application, jobType); - return successOn(application, jobType, Versions.from(change, application, existingDeployment, controller.systemVersion())).isPresent() + return application.deploymentJobs().statusOf(jobType).flatMap(JobStatus::lastSuccess) + .map(job -> change.platform().map(job.platform()::equals).orElse(true) + && change.application().map(job.application()::equals).orElse(true)) + .orElse(false) || jobType.isProduction() && existingDeployment.map(deployment -> ! isUpgrade(change, deployment) && isDowngrade(application.change(), deployment)) - .orElse(false); + .orElse(false); } private static boolean isUpgrade(Change change, Deployment deployment) { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java index b49722f2f2d..2bb878366c5 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java @@ -40,11 +40,9 @@ import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.List; import java.util.Map; -import java.util.NoSuchElementException; import java.util.Optional; import java.util.function.Supplier; import java.util.logging.Level; -import java.util.logging.LogRecord; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -238,7 +236,7 @@ public class InternalStepRunner implements StepRunner { ApplicationVersion application = setTheStage ? versions.sourceApplication().orElse(versions.targetApplication()) : versions.targetApplication(); logger.log("Checking installation of " + platform + " and " + application.id() + " ..."); - if (nodesConverged(id.application(), id.type(), platform, logger) && servicesConverged(id.application(), id.type())) { + if (nodesConverged(id.application(), id.type(), platform, logger) && servicesConverged(id.application(), id.type(), logger)) { logger.log("Installation succeeded!"); return Optional.of(running); } @@ -260,7 +258,7 @@ public class InternalStepRunner implements StepRunner { } logger.log("Checking installation of tester container ..."); - if (servicesConverged(JobController.testerOf(id.application()), id.type())) { + if (servicesConverged(JobController.testerOf(id.application()), id.type(), logger)) { logger.log("Tester container successfully installed!"); return Optional.of(running); } @@ -291,11 +289,21 @@ public class InternalStepRunner implements StepRunner { && node.rebootGeneration() == node.wantedRebootGeneration()); } - private boolean servicesConverged(ApplicationId id, JobType type) { - // TODO jvenstad: Print information for each host. - return controller.configServer().serviceConvergence(new DeploymentId(id, type.zone(controller.system()))) - .map(ServiceConvergence::converged) - .orElse(false); + private boolean servicesConverged(ApplicationId id, JobType type, DualLogger logger) { + Optional<ServiceConvergence> convergence = controller.configServer().serviceConvergence(new DeploymentId(id, type.zone(controller.system()))); + if ( ! convergence.isPresent()) { + logger.log("Config status not currently available -- will retry."); + return false; + } + logger.log("Wanted config generation is " + convergence.get().wantedGeneration()); + for (ServiceConvergence.Status serviceStatus : convergence.get().services()) + if (serviceStatus.currentGeneration() != convergence.get().wantedGeneration()) + logger.log(String.format("%70s: %11s on port %4d has %s", + serviceStatus.host().value(), + serviceStatus.type(), + serviceStatus.port(), + serviceStatus.currentGeneration() == -1 ? "(unknown)" : Long.toString(serviceStatus.currentGeneration()))); + return convergence.get().converged(); } private Optional<RunStatus> startTests(RunId id, DualLogger logger) { @@ -325,7 +333,7 @@ public class InternalStepRunner implements StepRunner { } Optional<URI> testerEndpoint = controller.jobController().testerEndpoint(id); - if (testerEndpoint.isPresent()) { + if (testerEndpoint.isPresent() && controller.jobController().cloud().ready(testerEndpoint.get())) { logger.log("Starting tests ..."); controller.jobController().cloud().startTests(testerEndpoint.get(), TesterCloud.Suite.of(id.type()), @@ -348,13 +356,16 @@ public class InternalStepRunner implements StepRunner { return Optional.of(aborted); } - URI testerEndpoint = controller.jobController().testerEndpoint(id) - .orElseThrow(() -> new NoSuchElementException("Endpoint for tester vanished again before tests were complete!")); + Optional<URI> testerEndpoint = controller.jobController().testerEndpoint(id); + if ( ! testerEndpoint.isPresent()) { + logger.log("Endpoints for tester not found -- trying again later."); + return Optional.empty(); + } controller.jobController().updateTestLog(id); RunStatus status; - TesterCloud.Status testStatus = controller.jobController().cloud().getStatus(testerEndpoint); + TesterCloud.Status testStatus = controller.jobController().cloud().getStatus(testerEndpoint.get()); switch (testStatus) { case NOT_STARTED: throw new IllegalStateException("Tester reports tests not started, even though they should have!"); @@ -491,7 +502,7 @@ public class InternalStepRunner implements StepRunner { " </filtering>\n" + " </http>\n" + "\n" + - " <nodes count=\"1\" flavor=\"d-2-8-50\" />\n" + + " <nodes count=\"1\" flavor=\"d-1-4-50\" />\n" + " </container>\n" + "</services>\n"; @@ -549,16 +560,13 @@ public class InternalStepRunner implements StepRunner { } private void log(Level level, String message, Throwable thrown) { - LogRecord record = new LogRecord(level, prefix + message); - record.setThrown(thrown); - logger.log(record); + logger.log(level, message, thrown); if (thrown != null) { ByteArrayOutputStream traceBuffer = new ByteArrayOutputStream(); thrown.printStackTrace(new PrintStream(traceBuffer)); message += "\n" + traceBuffer; } - controller.jobController().log(id, step, level, message); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobProfile.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobProfile.java index 0cad9e98d5d..f7794747db9 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobProfile.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobProfile.java @@ -15,7 +15,6 @@ import static com.yahoo.vespa.hosted.controller.deployment.Step.*; */ public enum JobProfile { - // TODO jvenstad: runTests is not a run-always step, as it really means: check if tests are done, and store whatever is ready. systemTest(EnumSet.of(deployReal, installReal, deployTester, diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ContactInformationMaintainer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ContactInformationMaintainer.java new file mode 100644 index 00000000000..5a825bf7b85 --- /dev/null +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ContactInformationMaintainer.java @@ -0,0 +1,71 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.maintenance; + +import com.yahoo.log.LogLevel; +import com.yahoo.vespa.hosted.controller.Controller; +import com.yahoo.vespa.hosted.controller.api.integration.organization.Organization; +import com.yahoo.vespa.hosted.controller.api.integration.organization.User; +import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; +import com.yahoo.vespa.hosted.controller.tenant.Contact; +import com.yahoo.vespa.hosted.controller.tenant.Tenant; +import com.yahoo.yolean.Exceptions; + +import java.time.Duration; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.logging.Logger; +import java.util.stream.Collectors; + +/** + * Periodically fetch and store contact information for tenants. + * + * @author mpolden + */ +public class ContactInformationMaintainer extends Maintainer { + + private static final Logger log = Logger.getLogger(ContactInformationMaintainer.class.getName()); + + private final Organization organization; + + public ContactInformationMaintainer(Controller controller, Duration interval, JobControl jobControl, Organization organization) { + super(controller, interval, jobControl); + this.organization = Objects.requireNonNull(organization, "organization must be non-null"); + } + + @Override + protected void maintain() { + for (Tenant t : controller().tenants().asList()) { + if (!(t instanceof AthenzTenant)) continue; // No contact information for non-Athenz tenants + AthenzTenant tenant = (AthenzTenant) t; + if (!tenant.propertyId().isPresent()) continue; // Can only update contact information if property ID is known + try { + findContact(tenant).ifPresent(contact -> { + controller().tenants().lockIfPresent(t.name(), lockedTenant -> controller().tenants().store(lockedTenant.with(contact))); + }); + } catch (Exception e) { + log.log(LogLevel.WARNING, "Failed to update contact information for " + tenant + ": " + + Exceptions.toMessageString(e) + ". Retrying in " + + maintenanceInterval()); + } + } + } + + /** Find contact information for given tenant */ + private Optional<Contact> findContact(AthenzTenant tenant) { + if (!tenant.propertyId().isPresent()) { + return Optional.empty(); + } + List<List<String>> persons = organization.contactsFor(tenant.propertyId().get()) + .stream() + .map(personList -> personList.stream() + .map(User::displayName) + .collect(Collectors.toList())) + .collect(Collectors.toList()); + return Optional.of(new Contact(organization.contactsUri(tenant.propertyId().get()), + organization.propertyUri(tenant.propertyId().get()), + organization.issueCreationUri(tenant.propertyId().get()), + persons)); + } + +} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java index 2c65ea0e3cb..c67eab8826e 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java @@ -5,13 +5,12 @@ import com.yahoo.component.AbstractComponent; import com.yahoo.jdisc.Metric; import com.yahoo.vespa.hosted.controller.Controller; import com.yahoo.vespa.hosted.controller.api.integration.chef.Chef; -import com.yahoo.vespa.hosted.controller.api.integration.deployment.TesterCloud; import com.yahoo.vespa.hosted.controller.api.integration.dns.NameService; import com.yahoo.vespa.hosted.controller.api.integration.noderepository.NodeRepositoryClientInterface; import com.yahoo.vespa.hosted.controller.api.integration.organization.DeploymentIssues; +import com.yahoo.vespa.hosted.controller.api.integration.organization.Organization; import com.yahoo.vespa.hosted.controller.api.integration.organization.OwnershipIssues; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; -import com.yahoo.vespa.hosted.controller.deployment.InternalStepRunner; import com.yahoo.vespa.hosted.controller.maintenance.config.MaintainerConfig; import com.yahoo.vespa.hosted.controller.persistence.CuratorDb; @@ -47,12 +46,14 @@ public class ControllerMaintenance extends AbstractComponent { private final List<OsUpgrader> osUpgraders; private final OsVersionStatusUpdater osVersionStatusUpdater; private final JobRunner jobRunner; + private final ContactInformationMaintainer contactInformationMaintainer; @SuppressWarnings("unused") // instantiated by Dependency Injection public ControllerMaintenance(MaintainerConfig maintainerConfig, Controller controller, CuratorDb curator, JobControl jobControl, Metric metric, Chef chefClient, DeploymentIssues deploymentIssues, OwnershipIssues ownershipIssues, - NameService nameService, NodeRepositoryClientInterface nodeRepositoryClient) { + NameService nameService, NodeRepositoryClientInterface nodeRepositoryClient, + Organization organization) { Duration maintenanceInterval = Duration.ofMinutes(maintainerConfig.intervalMinutes()); this.jobControl = jobControl; deploymentExpirer = new DeploymentExpirer(controller, maintenanceInterval, jobControl); @@ -71,6 +72,7 @@ public class ControllerMaintenance extends AbstractComponent { jobRunner = new JobRunner(controller, Duration.ofSeconds(30), jobControl); osUpgraders = osUpgraders(controller, jobControl); osVersionStatusUpdater = new OsVersionStatusUpdater(controller, maintenanceInterval, jobControl); + contactInformationMaintainer = new ContactInformationMaintainer(controller, Duration.ofHours(12), jobControl, organization); } public Upgrader upgrader() { return upgrader; } @@ -96,6 +98,7 @@ public class ControllerMaintenance extends AbstractComponent { osUpgraders.forEach(Maintainer::deconstruct); osVersionStatusUpdater.deconstruct(); jobRunner.deconstruct(); + contactInformationMaintainer.deconstruct(); } /** Create one OS upgrader per cloud found in the zone registry of controller */ diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/OsUpgrader.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/OsUpgrader.java index bf9fbeb26d3..7f3b2400736 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/OsUpgrader.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/OsUpgrader.java @@ -3,7 +3,6 @@ package com.yahoo.vespa.hosted.controller.maintenance; import com.google.common.collect.ImmutableSet; import com.yahoo.component.Version; -import com.yahoo.config.provision.SystemName; import com.yahoo.vespa.hosted.controller.Controller; import com.yahoo.vespa.hosted.controller.api.integration.configserver.Node; import com.yahoo.vespa.hosted.controller.api.integration.zone.CloudName; @@ -39,15 +38,11 @@ public class OsUpgrader extends InfrastructureUpgrader { } @Override - protected void maintain() { - if (controller().system() != SystemName.cd) return; // TODO: Enable in all systems - super.maintain(); - } - - @Override protected void upgrade(Version target, SystemApplication application, ZoneId zone) { + if (wantedVersion(zone, application, target).equals(target)) { + return; + } log.info(String.format("Upgrading OS of %s to version %s in %s", application.id(), target, zone)); - // Node repository ensures the upgrade call is idempotent application.nodeTypesWithUpgradableOs().forEach(nodeType -> controller().configServer().nodeRepository() .upgradeOs(zone, nodeType, target)); } @@ -75,6 +70,10 @@ public class OsUpgrader extends InfrastructureUpgrader { return minVersion(zone, application, Node::currentOsVersion).orElse(defaultVersion); } + private Version wantedVersion(ZoneId zone, SystemApplication application, Version defaultVersion) { + return minVersion(zone, application, Node::wantedOsVersion).orElse(defaultVersion); + } + /** Returns whether node in application should be upgraded by this */ public static boolean eligibleForUpgrade(Node node, SystemApplication application) { return upgradableNodeStates.contains(node.state()) && diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java index 763d26834e6..58e0b8dbeec 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java @@ -213,7 +213,7 @@ public class ApplicationSerializer { private void toSlime(ApplicationVersion applicationVersion, Cursor object) { if (applicationVersion.buildNumber().isPresent() && applicationVersion.source().isPresent()) { - object.setLong(applicationBuildNumberField, applicationVersion.buildNumber().get()); + object.setLong(applicationBuildNumberField, applicationVersion.buildNumber().getAsLong()); toSlime(applicationVersion.source().get(), object.setObject(sourceRevisionField)); } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/CuratorDb.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/CuratorDb.java index e117592d608..0d8ea8d2537 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/CuratorDb.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/CuratorDb.java @@ -302,12 +302,17 @@ public class CuratorDb { } public List<Tenant> readTenants() { + return readTenantNames().stream() + .map(this::readTenant) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(collectingAndThen(Collectors.toList(), Collections::unmodifiableList)); + } + + public List<TenantName> readTenantNames() { return curator.getChildren(tenantRoot).stream() .map(TenantName::from) - .map(this::readTenant) - .filter(Optional::isPresent) - .map(Optional::get) - .collect(collectingAndThen(Collectors.toList(), Collections::unmodifiableList)); + .collect(Collectors.toList()); } public void removeTenant(TenantName name) { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/LogSerializer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/LogSerializer.java index 457ef761c0f..17b4a42fb91 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/LogSerializer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/LogSerializer.java @@ -27,7 +27,6 @@ import java.util.stream.Collectors; class LogSerializer { private static final String idField = "id"; - private static final String levelField = "level"; private static final String typeField = "type"; private static final String timestampField = "at"; private static final String messageField = "message"; @@ -54,7 +53,6 @@ class LogSerializer { private void toSlime(LogEntry entry, Cursor entryObject) { entryObject.setLong(idField, entry.id()); entryObject.setLong(timestampField, entry.at()); - entryObject.setString(levelField, valueOf(entry.type())); // TODO jvenstad: Remove after one deployment. entryObject.setString(typeField, valueOf(entry.type())); entryObject.setString(messageField, entry.message()); } @@ -87,9 +85,7 @@ class LogSerializer { private LogEntry fromSlime(Inspector entryObject) { return new LogEntry(entryObject.field(idField).asLong(), entryObject.field(timestampField).asLong(), - entryObject.field(typeField).valid() // TODO jvenstad: Remove after one deployment. - ? typeOf(entryObject.field(typeField).asString()) - : typeOf(entryObject.field(levelField).asString()), + typeOf(entryObject.field(typeField).asString()), entryObject.field(messageField).asString()); } @@ -105,7 +101,7 @@ class LogSerializer { } static Type typeOf(String type) { - switch (type.toLowerCase()) { // TODO jvenstad: Remove lowercasing after this has been deployed. + switch (type) { case "debug": return Type.debug; case "info": return Type.info; case "warning": return Type.warning; diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializer.java index d55dc791462..28400b85306 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializer.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.hosted.controller.persistence; import com.yahoo.config.provision.TenantName; +import com.yahoo.slime.ArrayTraverser; import com.yahoo.slime.Cursor; import com.yahoo.slime.Inspector; import com.yahoo.slime.Slime; @@ -11,8 +12,12 @@ import com.yahoo.vespa.config.SlimeUtils; import com.yahoo.vespa.hosted.controller.api.identifiers.Property; import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; +import com.yahoo.vespa.hosted.controller.tenant.Contact; import com.yahoo.vespa.hosted.controller.tenant.UserTenant; +import java.net.URI; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; /** @@ -26,6 +31,12 @@ public class TenantSerializer { private static final String athenzDomainField = "athenzDomain"; private static final String propertyField = "property"; private static final String propertyIdField = "propertyId"; + private static final String contactField = "contact"; + private static final String contactUrlField = "contactUrl"; + private static final String propertyUrlField = "propertyUrl"; + private static final String issueTrackerUrlField = "issueTrackerUrl"; + private static final String personsField = "persons"; + private static final String personField = "person"; public Slime toSlime(AthenzTenant tenant) { Slime slime = new Slime(); @@ -34,6 +45,20 @@ public class TenantSerializer { root.setString(athenzDomainField, tenant.domain().getName()); root.setString(propertyField, tenant.property().id()); tenant.propertyId().ifPresent(propertyId -> root.setString(propertyIdField, propertyId.id())); + tenant.contact().ifPresent(contact -> { + Cursor contactObject = root.setObject(contactField); + contactObject.setString(contactUrlField, contact.url().toString()); + contactObject.setString(propertyUrlField, contact.propertyUrl().toString()); + contactObject.setString(issueTrackerUrlField, contact.issueTrackerUrl().toString()); + Cursor personsArray = contactObject.setArray(personsField); + contact.persons().forEach(personList -> { + Cursor personArray = personsArray.addArray(); + personList.forEach(person -> { + Cursor personObject = personArray.addObject(); + personObject.setString(personField, person); + }); + }); + }); return slime; } @@ -50,7 +75,8 @@ public class TenantSerializer { AthenzDomain domain = new AthenzDomain(root.field(athenzDomainField).asString()); Property property = new Property(root.field(propertyField).asString()); Optional<PropertyId> propertyId = SlimeUtils.optionalString(root.field(propertyIdField)).map(PropertyId::new); - return new AthenzTenant(name, domain, property, propertyId); + Optional<Contact> contact = contactFrom(root.field(contactField)); + return new AthenzTenant(name, domain, property, propertyId, contact); } public UserTenant userTenantFrom(Slime slime) { @@ -59,4 +85,24 @@ public class TenantSerializer { return new UserTenant(name); } + private Optional<Contact> contactFrom(Inspector object) { + if (!object.valid()) { + return Optional.empty(); + } + return Optional.of(new Contact(URI.create(object.field(contactUrlField).asString()), + URI.create(object.field(propertyUrlField).asString()), + URI.create(object.field(issueTrackerUrlField).asString()), + personsFrom(object.field(personsField)))); + } + + private List<List<String>> personsFrom(Inspector array) { + List<List<String>> personLists = new ArrayList<>(); + array.traverse((ArrayTraverser) (i, personArray) -> { + List<String> persons = new ArrayList<>(); + personArray.traverse((ArrayTraverser) (j, inspector) -> persons.add(inspector.field("person").asString())); + personLists.add(persons); + }); + return personLists; + } + } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java index 07286fda90b..034db3d487d 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java @@ -16,6 +16,7 @@ import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.container.jdisc.LoggingRequestHandler; import com.yahoo.io.IOUtils; import com.yahoo.log.LogLevel; +import com.yahoo.restapi.Path; import com.yahoo.slime.Cursor; import com.yahoo.slime.Inspector; import com.yahoo.slime.Slime; @@ -50,7 +51,6 @@ import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServ import com.yahoo.vespa.hosted.controller.api.integration.configserver.Log; import com.yahoo.vespa.hosted.controller.api.integration.deployment.JobType; import com.yahoo.vespa.hosted.controller.api.integration.deployment.RunId; -import com.yahoo.vespa.hosted.controller.api.integration.organization.User; import com.yahoo.vespa.hosted.controller.api.integration.routing.RotationStatus; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; @@ -66,7 +66,6 @@ import com.yahoo.vespa.hosted.controller.application.JobStatus; import com.yahoo.vespa.hosted.controller.application.SourceRevision; import com.yahoo.vespa.hosted.controller.restapi.ErrorResponse; import com.yahoo.vespa.hosted.controller.restapi.MessageResponse; -import com.yahoo.restapi.Path; import com.yahoo.vespa.hosted.controller.restapi.ResourceResponse; import com.yahoo.vespa.hosted.controller.restapi.SlimeJsonResponse; import com.yahoo.vespa.hosted.controller.restapi.StringResponse; @@ -169,6 +168,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { if (path.matches("/application/v4/tenant/{tenant}")) return tenant(path.get("tenant"), request); if (path.matches("/application/v4/tenant/{tenant}/application")) return applications(path.get("tenant"), request); if (path.matches("/application/v4/tenant/{tenant}/application/{application}")) return application(path.get("tenant"), path.get("application"), request); + if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/logs")) return logs(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region")); if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/job")) return JobControllerApiHandlerHelper.jobTypeResponse(controller, appIdFromPath(path), request.getUri()); if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/job/{jobtype}")) return JobControllerApiHandlerHelper.runResponse(controller.jobController().runs(appIdFromPath(path), jobTypeFromPath(path)), request.getUri()); if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/job/{jobtype}/run/{number}")) return JobControllerApiHandlerHelper.runDetailsResponse(controller.jobController(), runIdFromPath(path), request.getProperty("after")); @@ -346,9 +346,22 @@ public class ApplicationApiHandler extends LoggingRequestHandler { return new SlimeJsonResponse(slime); } + private HttpResponse logs(String tenantName, String applicationName, String instanceName, String environment, String region) { + ApplicationId application = ApplicationId.from(tenantName, applicationName, instanceName); + ZoneId zone = ZoneId.from(environment, region); + DeploymentId deployment = new DeploymentId(application, zone); + return controller.configServer().getLogs(deployment); + } + + private void toSlime(Cursor object, Application application, HttpRequest request) { object.setString("application", application.id().application().value()); object.setString("instance", application.id().instance().value()); + object.setString("deployments", withPath("/application/v4" + + "/tenant/" + application.id().tenant().value() + + "/application/" + application.id().application().value() + + "/instance/" + application.id().instance().value() + "/job/", + request.getUri()).toString()); // Currently deploying change if (application.change().isPresent()) { @@ -642,19 +655,27 @@ public class ApplicationApiHandler extends LoggingRequestHandler { } private HttpResponse updateTenant(String tenantName, HttpRequest request) { - Optional<AthenzTenant> existingTenant = controller.tenants().athenzTenant(TenantName.from(tenantName)); - if ( ! existingTenant.isPresent()) return ErrorResponse.notFoundError("Tenant '" + tenantName + "' does not exist"); + Optional<AthenzTenant> tenant = controller.tenants().athenzTenant(TenantName.from(tenantName)); + if ( ! tenant.isPresent()) return ErrorResponse.notFoundError("Tenant '" + tenantName + "' does not exist"); Inspector requestData = toSlime(request.getData()).get(); - AthenzTenant updatedTenant = existingTenant.get() - .with(new AthenzDomain(mandatory("athensDomain", requestData).asString())) - .with(new Property(mandatory("property", requestData).asString())); - Optional<PropertyId> propertyId = optional("propertyId", requestData).map(PropertyId::new); - if (propertyId.isPresent()) { - updatedTenant = updatedTenant.with(propertyId.get()); - } - controller.tenants().updateTenant(updatedTenant, requireNToken(request, "Could not update " + tenantName)); - return tenant(updatedTenant, request, true); + NToken token = requireNToken(request, "Could not update " + tenantName); + + controller.tenants().lockOrThrow(tenant.get().name(), lockedTenant -> { + lockedTenant = lockedTenant.with(new Property(mandatory("property", requestData).asString())); + lockedTenant = controller.tenants().withDomain( + lockedTenant, + new AthenzDomain(mandatory("athensDomain", requestData).asString()), + token + ); + Optional<PropertyId> propertyId = optional("propertyId", requestData).map(PropertyId::new); + if (propertyId.isPresent()) { + lockedTenant = lockedTenant.with(propertyId.get()); + } + controller.tenants().store(lockedTenant); + }); + + return tenant(controller.tenants().requireAthenzTenant(tenant.get().name()), request, true); } private HttpResponse createTenant(String tenantName, HttpRequest request) { @@ -897,13 +918,11 @@ public class ApplicationApiHandler extends LoggingRequestHandler { private void toSlime(Cursor object, Tenant tenant, HttpRequest request, boolean listApplications) { object.setString("tenant", tenant.name().value()); object.setString("type", tentantType(tenant)); - Optional<PropertyId> propertyId = Optional.empty(); if (tenant instanceof AthenzTenant) { AthenzTenant athenzTenant = (AthenzTenant) tenant; object.setString("athensDomain", athenzTenant.domain().getName()); object.setString("property", athenzTenant.property().id()); - propertyId = athenzTenant.propertyId(); - propertyId.ifPresent(id -> object.setString("propertyId", id.toString())); + athenzTenant.propertyId().ifPresent(id -> object.setString("propertyId", id.toString())); } Cursor applicationArray = object.setArray("applications"); if (listApplications) { // This cludge is needed because we call this after deleting the tenant. As this call makes another tenant lookup it will fail. TODO is to support lookup on tenant @@ -916,23 +935,19 @@ public class ApplicationApiHandler extends LoggingRequestHandler { } } } - propertyId.ifPresent(id -> { - try { - object.setString("propertyUrl", controller.organization().propertyUri(id).toString()); - object.setString("contactsUrl", controller.organization().contactsUri(id).toString()); - object.setString("issueCreationUrl", controller.organization().issueCreationUri(id).toString()); - Cursor lists = object.setArray("contacts"); - for (List<? extends User> contactList : controller.organization().contactsFor(id)) { - Cursor list = lists.addArray(); - for (User contact : contactList) - list.addString(contact.displayName()); - } - } - catch (RuntimeException e) { - log.log(Level.WARNING, "Error fetching property info for " + tenant + " with propertyId " + id + ": " + - Exceptions.toMessageString(e)); - } - }); + if (tenant instanceof AthenzTenant) { + AthenzTenant athenzTenant = (AthenzTenant) tenant; + athenzTenant.contact().ifPresent(c -> { + object.setString("propertyUrl", c.propertyUrl().toString()); + object.setString("contactsUrl", c.url().toString()); + object.setString("issueCreationUrl", c.issueTrackerUrl().toString()); + Cursor contactsArray = object.setArray("contacts"); + c.persons().forEach(persons -> { + Cursor personArray = contactsArray.addArray(); + persons.forEach(personArray::addString); + }); + }); + } } // A tenant has different content when in a list ... antipattern, but not solvable before application/v5 diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelper.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelper.java index 620a3514c87..4adea3383c5 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelper.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelper.java @@ -4,6 +4,7 @@ import com.google.common.base.Joiner; import com.yahoo.component.Version; import com.yahoo.config.application.api.DeploymentSpec; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.SystemName; import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.slime.Cursor; import com.yahoo.slime.Slime; @@ -119,7 +120,7 @@ class JobControllerApiHandlerHelper { Cursor jobsObject = responseObject.setObject("jobs"); steps.jobs().forEach(type -> { - jobTypeToSlime(jobsObject.setObject(type.jobName()), + jobTypeToSlime(jobsObject.setObject(shortNameOf(type, controller.system())), controller, application, type, @@ -142,13 +143,13 @@ class JobControllerApiHandlerHelper { lastVespa = version; Version lastPlatform = lastVespa.versionNumber(); - lastPlatformObject.setString("version", lastPlatform.toString()); + lastPlatformObject.setString("platform", lastPlatform.toString()); lastPlatformObject.setLong("at", lastVespa.committedAt().toEpochMilli()); long completed = steps.productionJobs().stream().filter(type -> controller.applications().deploymentTrigger().isComplete(Change.of(lastPlatform), application, type)).count(); if (Optional.of(lastPlatform).equals(change.platform())) - lastPlatformObject.setString("deploying", completed + " of " + steps.productionJobs().size()); + lastPlatformObject.setString("deploying", completed + " of " + steps.productionJobs().size() + " complete"); else if (completed == steps.productionJobs().size()) - lastPlatformObject.setString("completed", completed + " of " + steps.productionJobs().size()); + lastPlatformObject.setString("completed", completed + " of " + steps.productionJobs().size() + " complete"); else if ( ! application.deploymentSpec().canUpgradeAt(controller.clock().instant())) { lastPlatformObject.setString("blocked", application.deploymentSpec().changeBlocker().stream() .filter(blocker -> blocker.blocksVersions()) @@ -156,13 +157,16 @@ class JobControllerApiHandlerHelper { .findAny().map(blocker -> blocker.window().toString()).get()); } else - lastPlatformObject.setString("pending", "Waiting for current deployment to complete"); + lastPlatformObject.setString("pending", + application.changeAt(controller.clock().instant()).isPresent() + ? "Waiting for current deployment to complete" + : "Waiting for upgrade slot"); } private static void lastApplicationToSlime(Cursor lastApplicationObject, Application application, Change change, DeploymentSteps steps, Controller controller) { long completed; ApplicationVersion lastApplication = application.deploymentJobs().statusOf(component).flatMap(JobStatus::lastSuccess).get().application(); - applicationVersionToSlime(lastApplicationObject.setObject("version"), lastApplication); + applicationVersionToSlime(lastApplicationObject.setObject("application"), lastApplication); lastApplicationObject.setLong("at", application.deploymentJobs().statusOf(component).flatMap(JobStatus::lastSuccess).get().at().toEpochMilli()); completed = steps.productionJobs().stream().filter(type -> controller.applications().deploymentTrigger().isComplete(Change.of(lastApplication), application, type)).count(); if (Optional.of(lastApplication).equals(change.application())) @@ -241,11 +245,11 @@ class JobControllerApiHandlerHelper { if ( ! controller.applications().deploymentTrigger().alreadyTriggered(application, versions)) { if ( ! controller.applications().deploymentTrigger().testedIn(application, systemTest, versions)) { pending++; - pendingObject.setString(systemTest.jobName(), statusOf(controller, application.id(), systemTest, versions)); + pendingObject.setString(shortNameOf(systemTest, controller.system()), statusOf(controller, application.id(), systemTest, versions)); } if ( ! controller.applications().deploymentTrigger().testedIn(application, stagingTest, versions)) { pending++; - pendingObject.setString(stagingTest.jobName(), statusOf(controller, application.id(), stagingTest, versions)); + pendingObject.setString(shortNameOf(stagingTest, controller.system()), statusOf(controller, application.id(), stagingTest, versions)); } } steps: for (DeploymentSpec.Step step : steps.production()) { @@ -253,7 +257,11 @@ class JobControllerApiHandlerHelper { break; for (JobType stepType : steps.toJobs(step)) { if (pendingProduction.containsKey(stepType)) { - pendingObject.setString(stepType.jobName(), statusOf(controller, application.id(), stepType, versions)); + Versions jobVersions = Versions.from(application.changeAt(controller.clock().instant()), + application, + Optional.ofNullable(application.deployments().get(stepType.zone(controller.system()))), + controller.systemVersion()); + pendingObject.setString(shortNameOf(stepType, controller.system()), statusOf(controller, application.id(), stepType, jobVersions)); if (++pending == 3) break steps; } @@ -274,12 +282,16 @@ class JobControllerApiHandlerHelper { private static String statusOf(Controller controller, ApplicationId id, JobType type, Versions versions) { return controller.jobController().last(id, type) - .filter(run -> versions.targetsMatch(versions)) - .filter(run -> type == systemTest || versions.sourcesMatchIfPresent(versions)) + .filter(run -> run.versions().targetsMatch(versions)) + .filter(run -> type != stagingTest || run.versions().sourcesMatchIfPresent(versions)) .map(JobControllerApiHandlerHelper::taskStatusOf) .orElse("pending"); } + private static String shortNameOf(JobType type, SystemName system) { + return type.isProduction() ? type.zone(system).region().value() : type.jobName(); + } + private static String taskStatusOf(Run run) { switch (run.status()) { case running: return "running"; @@ -333,11 +345,12 @@ class JobControllerApiHandlerHelper { } private static void applicationVersionToSlime(Cursor versionObject, ApplicationVersion version) { - versionObject.setString("id", version.id()); - versionObject.setLong("build", version.buildNumber().get()); - versionObject.setString("repository", version.source().get().repository()); - versionObject.setString("branch", version.source().get().branch()); - versionObject.setString("commit", version.source().get().commit()); + versionObject.setString("hash", version.id()); + versionObject.setLong("build", version.buildNumber().getAsLong()); + Cursor sourceObject = versionObject.setObject("source"); + sourceObject.setString("gitRepository", version.source().get().repository()); + sourceObject.setString("gitBranch", version.source().get().branch()); + sourceObject.setString("gitCommit", version.source().get().commit()); } /** diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/AthenzTenant.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/AthenzTenant.java index 0ba0eea2dab..8cbb4e06aca 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/AthenzTenant.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/AthenzTenant.java @@ -6,6 +6,7 @@ import com.yahoo.vespa.athenz.api.AthenzDomain; import com.yahoo.vespa.hosted.controller.api.identifiers.Property; import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; +import java.util.Objects; import java.util.Optional; /** @@ -18,16 +19,19 @@ public class AthenzTenant extends Tenant { private final AthenzDomain domain; private final Property property; private final Optional<PropertyId> propertyId; + private final Optional<Contact> contact; /** * This should only be used by serialization. * Use {@link #create(TenantName, AthenzDomain, Property, Optional)}. * */ - public AthenzTenant(TenantName name, AthenzDomain domain, Property property, Optional<PropertyId> propertyId) { + public AthenzTenant(TenantName name, AthenzDomain domain, Property property, Optional<PropertyId> propertyId, + Optional<Contact> contact) { super(name); - this.domain = domain; - this.property = property; - this.propertyId = propertyId; + this.domain = Objects.requireNonNull(domain, "domain must be non-null"); + this.property = Objects.requireNonNull(property, "property must be non-null"); + this.propertyId = Objects.requireNonNull(propertyId, "propertyId must be non-null"); + this.contact = Objects.requireNonNull(contact, "contact must be non-null"); } /** Property name of this tenant */ @@ -35,11 +39,16 @@ public class AthenzTenant extends Tenant { return property; } - /** Property ID of the tenant, if present */ + /** Property ID of the tenant, if any */ public Optional<PropertyId> propertyId() { return propertyId; } + /** Contact information for this, if any */ + public Optional<Contact> contact() { + return contact; + } + /** Athenz domain of this tenant */ public AthenzDomain domain() { return domain; @@ -55,22 +64,10 @@ public class AthenzTenant extends Tenant { return "athenz tenant '" + name() + "'"; } - public AthenzTenant with(AthenzDomain domain) { - return new AthenzTenant(name(), domain, property(), propertyId()); - } - - public AthenzTenant with(Property property) { - return new AthenzTenant(name(), domain, property, propertyId()); - } - - public AthenzTenant with(PropertyId propertyId) { - return new AthenzTenant(name(), domain, property, Optional.of(propertyId)); - } - /** Create a new Athenz tenant */ public static AthenzTenant create(TenantName name, AthenzDomain domain, Property property, Optional<PropertyId> propertyId) { - return new AthenzTenant(requireName(requireNoPrefix(name)), domain, property, propertyId); + return new AthenzTenant(requireName(requireNoPrefix(name)), domain, property, propertyId, Optional.empty()); } private static TenantName requireNoPrefix(TenantName name) { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Contact.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Contact.java new file mode 100644 index 00000000000..e13b0f982da --- /dev/null +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Contact.java @@ -0,0 +1,75 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.tenant; + +import com.google.common.collect.ImmutableList; + +import java.net.URI; +import java.util.List; +import java.util.Objects; + +/** + * Contact information for a tenant. + * + * @author mpolden + */ +public class Contact { + + private final URI url; + private final URI propertyUrl; + private final URI issueTrackerUrl; + private final List<List<String>> persons; + + public Contact(URI url, URI propertyUrl, URI issueTrackerUrl, List<List<String>> persons) { + this.propertyUrl = Objects.requireNonNull(propertyUrl, "propertyUrl must be non-null"); + this.url = Objects.requireNonNull(url, "url must be non-null"); + this.issueTrackerUrl = Objects.requireNonNull(issueTrackerUrl, "issueTrackerUrl must be non-null"); + this.persons = ImmutableList.copyOf(Objects.requireNonNull(persons, "persons must be non-null")); + } + + /** URL to this */ + public URI url() { + return url; + } + + /** URL to information about this property */ + public URI propertyUrl() { + return propertyUrl; + } + + /** URL to this contacts's issue tracker */ + public URI issueTrackerUrl() { + return issueTrackerUrl; + } + + /** Nested list of persons representing this. First level represents that person's rank in the corporate dystopia. */ + public List<List<String>> persons() { + return persons; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Contact contact = (Contact) o; + return Objects.equals(url, contact.url) && + Objects.equals(propertyUrl, contact.propertyUrl) && + Objects.equals(issueTrackerUrl, contact.issueTrackerUrl) && + Objects.equals(persons, contact.persons); + } + + @Override + public int hashCode() { + return Objects.hash(url, propertyUrl, issueTrackerUrl, persons); + } + + @Override + public String toString() { + return "Contact{" + + "url=" + url + + ", propertyUrl=" + propertyUrl + + ", issueTrackerUrl=" + issueTrackerUrl + + ", persons=" + persons + + '}'; + } + +} diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java index c067bccb4c3..61b921aa6c1 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java @@ -78,6 +78,7 @@ public final class ControllerTester { private final MockBuildService buildService; private final MetricsServiceMock metricsService; private final RoutingGeneratorMock routingGenerator; + private final MockOrganization organization; private Controller controller; @@ -87,7 +88,7 @@ public final class ControllerTester { new ZoneRegistryMock(), new GitHubMock(), curatorDb, rotationsConfig, new MemoryNameService(), new ArtifactRepositoryMock(), new ApplicationStoreMock(), new MemoryEntityService(), new MockBuildService(), - metricsService, new RoutingGeneratorMock()); + metricsService, new RoutingGeneratorMock(), new MockOrganization(clock)); } public ControllerTester(ManualClock clock) { @@ -112,7 +113,8 @@ public final class ControllerTester { MemoryNameService nameService, ArtifactRepositoryMock artifactRepository, ApplicationStoreMock appStoreMock, EntityService entityService, MockBuildService buildService, - MetricsServiceMock metricsService, RoutingGeneratorMock routingGenerator) { + MetricsServiceMock metricsService, RoutingGeneratorMock routingGenerator, + MockOrganization organization) { this.athenzDb = athenzDb; this.clock = clock; this.configServer = configServer; @@ -127,6 +129,7 @@ public final class ControllerTester { this.buildService = buildService; this.metricsService = metricsService; this.routingGenerator = routingGenerator; + this.organization = organization; this.controller = createController(curator, rotationsConfig, configServer, clock, gitHub, zoneRegistry, athenzDb, nameService, artifactRepository, appStoreMock, entityService, buildService, metricsService, routingGenerator); @@ -175,6 +178,10 @@ public final class ControllerTester { public RoutingGeneratorMock routingGenerator() { return routingGenerator; } + public MockOrganization organization() { + return organization; + } + /** Create a new controller instance. Useful to verify that controller state is rebuilt from persistence */ public final void createNewController() { controller = createController(curator, rotationsConfig, configServer, clock, gitHub, zoneRegistry, athenzDb, @@ -197,12 +204,6 @@ public final class ControllerTester { } /** Creates the given tenant and application and deploys it */ - public Application createAndDeploy(String tenantName, String domainName, String applicationName, - String instanceName, Environment environment, long projectId, Long propertyId) { - return createAndDeploy(tenantName, domainName, applicationName, instanceName, toZone(environment), projectId, propertyId); - } - - /** Creates the given tenant and application and deploys it */ public Application createAndDeploy(String tenantName, String domainName, String applicationName, ZoneId zone, long projectId, Long propertyId) { return createAndDeploy(tenantName, domainName, applicationName, "default", zone, projectId, propertyId); } @@ -300,7 +301,6 @@ public final class ControllerTester { rotationsConfig, gitHub, entityService, - new MockOrganization(clock), new MemoryGlobalRoutingService(), zoneRegistryMock, configServer, diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/concurrent/OnceTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/concurrent/OnceTest.java new file mode 100644 index 00000000000..e11fdcba7c6 --- /dev/null +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/concurrent/OnceTest.java @@ -0,0 +1,25 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.concurrent; + +import org.junit.Test; + +import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertTrue; + +/** + * @author mpolden + */ +public class OnceTest { + + @Test(timeout = 60_000) + public void test_run() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + Once.after(Duration.ZERO, latch::countDown); + + assertTrue(latch.await(30, TimeUnit.SECONDS)); + } + +} diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java index f5fc6825960..3b381e21b27 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java @@ -375,9 +375,9 @@ public class DeploymentTriggerTest { // This component completion should remove the older outstanding change, to avoid a later downgrade. clock.advance(Duration.ofHours(1)); tester.deployAndNotify(application, applicationPackage, true, productionUsWest1); - assertEquals((Long) BuildJob.defaultBuildNumber, tester.application(application.id()).deploymentJobs().jobStatus() - .get(productionUsWest1).lastSuccess().get().application().buildNumber().get()); - assertEquals((Long) (BuildJob.defaultBuildNumber + 1), tester.application(application.id()).outstandingChange().application().get().buildNumber().get()); + assertEquals(BuildJob.defaultBuildNumber, tester.application(application.id()).deploymentJobs().jobStatus() + .get(productionUsWest1).lastSuccess().get().application().buildNumber().getAsLong()); + assertEquals((BuildJob.defaultBuildNumber + 1), tester.application(application.id()).outstandingChange().application().get().buildNumber().getAsLong()); tester.readyJobTrigger().maintain(); assertTrue(tester.buildService().jobs().isEmpty()); @@ -513,14 +513,14 @@ public class DeploymentTriggerTest { tester.assertRunning(productionUsCentral1, application.id()); assertEquals(v2, app.get().deployments().get(productionUsCentral1.zone(main)).version()); - assertEquals(Long.valueOf(42L), app.get().deployments().get(productionUsCentral1.zone(main)).applicationVersion().buildNumber().get()); + assertEquals(42, app.get().deployments().get(productionUsCentral1.zone(main)).applicationVersion().buildNumber().getAsLong()); assertNotEquals(triggered, app.get().deploymentJobs().jobStatus().get(productionUsCentral1).lastTriggered().get().at()); // Change has a higher application version than what is deployed -- deployment should trigger. tester.deployAndNotify(application, applicationPackage, false, productionUsCentral1); tester.deploy(productionUsCentral1, application, applicationPackage); assertEquals(v2, app.get().deployments().get(productionUsCentral1.zone(main)).version()); - assertEquals(Long.valueOf(43), app.get().deployments().get(productionUsCentral1.zone(main)).applicationVersion().buildNumber().get()); + assertEquals(43, app.get().deployments().get(productionUsCentral1.zone(main)).applicationVersion().buildNumber().getAsLong()); // Change is again strictly dominated, and us-central-1 is skipped, even though it is still failing. tester.clock().advance(Duration.ofHours(2).plus(Duration.ofSeconds(1))); // Enough time for retry @@ -588,8 +588,8 @@ public class DeploymentTriggerTest { tester.deployAndNotify(application, true, productionUsEast3); tester.deployAndNotify(application, true, productionEuWest1); assertFalse(app.get().change().isPresent()); - assertEquals(43, app.get().deploymentJobs().jobStatus().get(productionEuWest1).lastSuccess().get().application().buildNumber().get().longValue()); - assertEquals(43, app.get().deploymentJobs().jobStatus().get(productionUsEast3).lastSuccess().get().application().buildNumber().get().longValue()); + assertEquals(43, app.get().deploymentJobs().jobStatus().get(productionEuWest1).lastSuccess().get().application().buildNumber().getAsLong()); + assertEquals(43, app.get().deploymentJobs().jobStatus().get(productionUsEast3).lastSuccess().get().application().buildNumber().getAsLong()); } @Test diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunnerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunnerTest.java index a83415d902d..f33f82b78e2 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunnerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunnerTest.java @@ -182,14 +182,6 @@ public class InternalStepRunnerTest { } @Test - public void testsFailIfTesterEndpointsVanish() { - RunId id = tester.startSystemTestTests(); - tester.routing().removeEndpoints(new DeploymentId(testerOf(InternalDeploymentTester.appId), JobType.systemTest.zone(tester.tester().controller().system()))); - tester.runner().run(); - assertEquals(failed, tester.jobs().run(id).get().steps().get(Step.endTests)); - } - - @Test public void testsFailIfTesterRestarts() { RunId id = tester.startSystemTestTests(); tester.cloud().set(TesterCloud.Status.NOT_STARTED); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java index dc9f3246e80..bd65465633e 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java @@ -7,6 +7,7 @@ import com.yahoo.component.Version; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.HostName; import com.yahoo.config.provision.NodeType; +import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.vespa.hosted.controller.api.application.v4.model.DeployOptions; import com.yahoo.vespa.hosted.controller.api.application.v4.model.EndpointStatus; import com.yahoo.vespa.hosted.controller.api.application.v4.model.configserverbindings.ConfigChangeActions; @@ -25,7 +26,11 @@ import com.yahoo.vespa.hosted.controller.application.SystemApplication; import com.yahoo.vespa.serviceview.bindings.ApplicationView; import com.yahoo.vespa.serviceview.bindings.ClusterView; import com.yahoo.vespa.serviceview.bindings.ServiceView; +import org.json.JSONException; +import org.json.JSONObject; +import java.io.IOException; +import java.io.OutputStream; import java.net.URI; import java.util.ArrayList; import java.util.Arrays; @@ -110,7 +115,17 @@ public class ConfigServerMock extends AbstractComponent implements ConfigServer /** Converge all services belonging to the given application */ public void convergeServices(ApplicationId application, ZoneId zone) { - serviceStatus.put(new DeploymentId(application, zone), new ServiceConvergence(application, zone, true)); + List<Node> nodes = nodeRepository.list(zone, application); + serviceStatus.put(new DeploymentId(application, zone), new ServiceConvergence(application, + zone, + true, + 2, + nodes.stream() + .map(node -> new ServiceConvergence.Status(node.hostname(), + 43, + "container", + 2)) + .collect(Collectors.toList()))); } /** The version given in the previous prepare call, or empty if no call has been made */ @@ -189,14 +204,24 @@ public class ConfigServerMock extends AbstractComponent implements ConfigServer public PrepareResponse prepareResponse() { Application application = applications.get(deployment.applicationId()); application.activate(); - for (Node node : nodeRepository.list(deployment.zoneId(), deployment.applicationId())) { + List<Node> nodes = nodeRepository.list(deployment.zoneId(), deployment.applicationId()); + for (Node node : nodes) { nodeRepository.putByHostname(deployment.zoneId(), new Node(node.hostname(), node.state(), node.type(), node.owner(), node.currentVersion(), application.version().get())); } - serviceStatus.remove(deployment); // Deployment is no longer converging after new deployment + serviceStatus.put(deployment, new ServiceConvergence(deployment.applicationId(), + deployment.zoneId(), + false, + 2, + nodes.stream() + .map(node -> new ServiceConvergence.Status(node.hostname(), + 43, + "container", + 1)) + .collect(Collectors.toList()))); PrepareResponse prepareResponse = new PrepareResponse(); prepareResponse.message = "foo"; @@ -223,6 +248,7 @@ public class ConfigServerMock extends AbstractComponent implements ConfigServer applications.remove(deployment.applicationId()); nodeRepository().removeByHostname(deployment.zoneId(), nodeRepository().list(deployment.zoneId(), deployment.applicationId())); + serviceStatus.remove(deployment); } // Returns a canned example response @@ -271,6 +297,17 @@ public class ConfigServerMock extends AbstractComponent implements ConfigServer return endpoints.getOrDefault(endpoint, result); } + @Override + public HttpResponse getLogs(DeploymentId deployment) { + return new HttpResponse(200) { + @Override + public void render(OutputStream outputStream) throws IOException { + outputStream.write("{\"subfolder\":{\"log2.log\":\"VGhpcyBpcyBhbm90aGVyIGxvZyBmaWxl\"},\"log1.log\":\"VGhpcyBpcyBvbmUgbG9nIGZpbGU=\"}".getBytes()); + } + }; + + } + public static class Application { private final ApplicationId id; diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ContactInformationMaintainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ContactInformationMaintainerTest.java new file mode 100644 index 00000000000..cbaa37b15e3 --- /dev/null +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ContactInformationMaintainerTest.java @@ -0,0 +1,77 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.maintenance; + +import com.yahoo.config.provision.TenantName; +import com.yahoo.vespa.hosted.controller.ControllerTester; +import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; +import com.yahoo.vespa.hosted.controller.api.integration.organization.User; +import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; +import com.yahoo.vespa.hosted.controller.tenant.Contact; +import org.junit.Before; +import org.junit.Test; + +import java.net.URI; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * @author mpolden + */ +public class ContactInformationMaintainerTest { + + private ControllerTester tester; + private ContactInformationMaintainer maintainer; + + @Before + public void before() { + tester = new ControllerTester(); + maintainer = new ContactInformationMaintainer(tester.controller(), Duration.ofDays(1), + new JobControl(tester.controller().curator()), + tester.organization()); + } + + @Test + public void updates_contact_information() { + long propertyId = 1; + TenantName name = tester.createTenant("tenant1", "domain1", propertyId); + Supplier<AthenzTenant> tenant = () -> tester.controller().tenants().requireAthenzTenant(name); + assertFalse("No contact information initially", tenant.get().contact().isPresent()); + + Contact contact = testContact(); + registerContact(propertyId, contact); + maintainer.run(); + + assertTrue("Contact information added", tenant.get().contact().isPresent()); + assertEquals(contact, tenant.get().contact().get()); + } + + private void registerContact(long propertyId, Contact contact) { + PropertyId p = new PropertyId(String.valueOf(propertyId)); + tester.organization().addProperty(p) + .setContactsUrl(p, contact.url()) + .setIssueUrl(p, contact.issueTrackerUrl()) + .setPropertyUrl(p, contact.propertyUrl()) + .setContactsFor(p, contact.persons().stream().map(persons -> persons.stream() + .map(User::from) + .collect(Collectors.toList())) + .collect(Collectors.toList())); + } + + private static Contact testContact() { + URI contactUrl = URI.create("http://contact1.test"); + URI issueTrackerUrl = URI.create("http://issue-tracker1.test"); + URI propertyUrl = URI.create("http://property1.test"); + List<List<String>> persons = Arrays.asList(Collections.singletonList("alice"), + Collections.singletonList("bob")); + return new Contact(contactUrl, propertyUrl, issueTrackerUrl, persons); + } + +} diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/JobRunnerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/JobRunnerTest.java index a77f0789314..0b2863dab1d 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/JobRunnerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/JobRunnerTest.java @@ -75,7 +75,6 @@ public class JobRunnerTest { public void multiThreadedExecutionFinishes() throws InterruptedException { DeploymentTester tester = new DeploymentTester(); JobController jobs = tester.controller().jobController(); - // Fail the installation of the initial version of the real application in staging tests, and succeed everything else. StepRunner stepRunner = (step, id) -> id.type() == stagingTest && step.get() == startTests? Optional.of(error) : Optional.of(running); CountDownLatch latch = new CountDownLatch(19); // Number of steps that will run, below: all but endTests in staging and all 9 in system. JobRunner runner = new JobRunner(tester.controller(), Duration.ofDays(1), new JobControl(tester.controller().curator()), @@ -93,9 +92,10 @@ public class JobRunnerTest { jobs.start(id, stagingTest, versions); assertTrue(jobs.last(id, systemTest).get().steps().values().stream().allMatch(unfinished::equals)); - runner.maintain(); assertFalse(jobs.last(id, systemTest).get().hasEnded()); + assertTrue(jobs.last(id, stagingTest).get().steps().values().stream().allMatch(unfinished::equals)); assertFalse(jobs.last(id, stagingTest).get().hasEnded()); + runner.maintain(); latch.await(1, TimeUnit.SECONDS); assertEquals(0, latch.getCount()); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/OsUpgraderTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/OsUpgraderTest.java index 045386dd93a..74f7ab7faf2 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/OsUpgraderTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/OsUpgraderTest.java @@ -119,33 +119,6 @@ public class OsUpgraderTest { .allMatch(node -> node.version().equals(version1))); } - // TODO: Remove once enabled in all systems - @Test - public void os_upgrade_in_main_does_nothing() { - OsUpgrader osUpgrader = osUpgrader( - UpgradePolicy.create() - .upgrade(zone1) - .upgradeInParallel(zone2, zone3) - .upgrade(zone4), - SystemName.main - ); - - // Bootstrap system - tester.configServer().bootstrap(Arrays.asList(zone1, zone2, zone3, zone4, zone5), - singletonList(SystemApplication.zone), - Optional.of(NodeType.host)); - - // New OS is released - CloudName cloud = CloudName.defaultName(); - Version version1 = Version.fromString("7.1"); - tester.controller().upgradeOsIn(cloud, version1); - statusUpdater.maintain(); - - // Nothing happens as main is explicitly disabled - osUpgrader.maintain(); - assertWanted(Version.emptyVersion, SystemApplication.zone, zone1); - } - private List<OsVersionStatus.Node> nodesOn(Version version) { return tester.controller().osVersionStatus().versions().entrySet().stream() .filter(entry -> entry.getKey().version().equals(version)) diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/SystemUpgraderTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/SystemUpgraderTest.java index 4f45b25b1a1..bb52b8fe20f 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/SystemUpgraderTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/SystemUpgraderTest.java @@ -88,9 +88,12 @@ public class SystemUpgraderTest { assertWantedVersion(SystemApplication.zone, version1, zone2, zone3, zone4); // zone 2 and 3: upgrade does not start until zone 1 zone-application config converges + // TODO: Commented out for now, see comment in SystemApplication.configConvergedIn() + /* systemUpgrader.maintain(); assertWantedVersion(SystemApplication.configServer, version1, zone2, zone3); convergeServices(SystemApplication.zone, zone1); + */ // zone 2 and 3: zone-config-server upgrades, first in zone 2, then in zone 3 systemUpgrader.maintain(); @@ -127,11 +130,14 @@ public class SystemUpgraderTest { completeUpgrade(SystemApplication.zone, version2, zone4); // zone 4: System version remains unchanged until config converges + // TODO: Commented out for now, see comment in SystemApplication.configConvergedIn() + /* tester.computeVersionStatus(); assertSystemVersion(version1); convergeServices(SystemApplication.zone, zone4); tester.computeVersionStatus(); assertSystemVersion(version2); + */ // Next run does nothing as system is now upgraded systemUpgrader.maintain(); @@ -159,12 +165,16 @@ public class SystemUpgraderTest { systemUpgrader.maintain(); completeUpgrade(SystemApplication.zone, version2, zone1); tester.computeVersionStatus(); - assertSystemVersion(version1); // Unchanged until zone-application converges + // TODO: Changed for now, see comment in SystemApplication.configConvergedIn() + //assertSystemVersion(version1); // Unchanged until zone-application converges + assertSystemVersion(version2); // Controller upgrades again Version version3 = Version.fromString("6.7"); tester.upgradeController(version3); - assertSystemVersion(version1); + // TODO: Changed for now, see todo above + //assertSystemVersion(version1); + assertSystemVersion(version2); assertControllerVersion(version3); // zone 1: zone-application converges and system version changes diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/RunSerializerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/RunSerializerTest.java index 82aee3b3550..de9fe3f3dcc 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/RunSerializerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/RunSerializerTest.java @@ -30,12 +30,13 @@ import static com.yahoo.vespa.hosted.controller.deployment.Step.deactivateTester import static com.yahoo.vespa.hosted.controller.deployment.Step.deployInitialReal; import static com.yahoo.vespa.hosted.controller.deployment.Step.deployReal; import static com.yahoo.vespa.hosted.controller.deployment.Step.deployTester; +import static com.yahoo.vespa.hosted.controller.deployment.Step.endTests; import static com.yahoo.vespa.hosted.controller.deployment.Step.installInitialReal; import static com.yahoo.vespa.hosted.controller.deployment.Step.installReal; import static com.yahoo.vespa.hosted.controller.deployment.Step.installTester; import static com.yahoo.vespa.hosted.controller.deployment.Step.report; import static com.yahoo.vespa.hosted.controller.deployment.Step.startTests; -import static com.yahoo.vespa.hosted.controller.deployment.Step.endTests; +import static java.time.temporal.ChronoUnit.MILLIS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -97,7 +98,7 @@ public class RunSerializerTest { .build(), run.steps()); - run = run.aborted().finished(Instant.now()); + run = run.aborted().finished(Instant.now().truncatedTo(MILLIS)); assertEquals(aborted, run.status()); assertTrue(run.hasEnded()); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializerTest.java index fd909482072..38b09024cdf 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializerTest.java @@ -5,9 +5,13 @@ import com.yahoo.vespa.athenz.api.AthenzDomain; import com.yahoo.vespa.hosted.controller.api.identifiers.Property; import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; +import com.yahoo.vespa.hosted.controller.tenant.Contact; import com.yahoo.vespa.hosted.controller.tenant.UserTenant; import org.junit.Test; +import java.net.URI; +import java.util.Arrays; +import java.util.Collections; import java.util.Optional; import static org.junit.Assert.assertEquals; @@ -47,6 +51,25 @@ public class TenantSerializerTest { } @Test + public void athenz_tenant_with_contact() { + AthenzTenant tenant = new AthenzTenant(TenantName.from("athenz-tenant"), + new AthenzDomain("domain1"), + new Property("property1"), + Optional.of(new PropertyId("1")), + Optional.of(new Contact( + URI.create("http://contact1.test"), + URI.create("http://property1.test"), + URI.create("http://issue-tracker-1.test"), + Arrays.asList( + Collections.singletonList("person1"), + Collections.singletonList("person2") + ) + ))); + AthenzTenant serialized = serializer.athenzTenantFrom(serializer.toSlime(tenant)); + assertEquals(tenant.contact(), serialized.contact()); + } + + @Test public void user_tenant() { UserTenant tenant = UserTenant.create("by-foo"); UserTenant serialized = serializer.userTenantFrom(serializer.toSlime(tenant)); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/testdata/logs.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/testdata/logs.json index a6a092109a1..ce9bd2139c7 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/testdata/logs.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/testdata/logs.json @@ -3,13 +3,13 @@ [ { "id": 0, - "level": "info", + "type": "info", "at": 0, "message": "First" }, { "id": 2, - "level": "debug", + "type": "debug", "at": 1000, "message": "Third" } @@ -18,13 +18,13 @@ [ { "id": 1, - "level": "info", + "type": "info", "at": 0, "message": "Second" }, { "id": 3, - "level": "warning", + "type": "warning", "at": 2000, "message": "Fourth" } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java index 017479ecc90..0ea23ae1b78 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java @@ -46,6 +46,8 @@ import com.yahoo.vespa.hosted.controller.athenz.mock.AthenzDbMock; import com.yahoo.vespa.hosted.controller.deployment.ApplicationPackageBuilder; import com.yahoo.vespa.hosted.controller.deployment.BuildJob; import com.yahoo.vespa.hosted.controller.integration.ConfigServerMock; +import com.yahoo.vespa.hosted.controller.maintenance.ContactInformationMaintainer; +import com.yahoo.vespa.hosted.controller.maintenance.JobControl; import com.yahoo.vespa.hosted.controller.restapi.ContainerControllerTester; import com.yahoo.vespa.hosted.controller.restapi.ContainerTester; import com.yahoo.vespa.hosted.controller.restapi.ControllerContainerTest; @@ -53,6 +55,7 @@ import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; import org.apache.http.HttpEntity; import org.apache.http.entity.ContentType; import org.apache.http.entity.mime.MultipartEntityBuilder; +import org.junit.Before; import org.junit.Test; import java.io.ByteArrayOutputStream; @@ -61,6 +64,7 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.net.URI; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; @@ -108,10 +112,18 @@ public class ApplicationApiTest extends ControllerContainerTest { private static final ZoneId TEST_ZONE = ZoneId.from(Environment.test, RegionName.from("us-east-1")); private static final ZoneId STAGING_ZONE = ZoneId.from(Environment.staging, RegionName.from("us-east-3")); + + private ContainerControllerTester controllerTester; + private ContainerTester tester; + + @Before + public void before() { + controllerTester = new ContainerControllerTester(container, responseFiles); + tester = controllerTester.containerTester(); + } + @Test - public void testApplicationApi() throws Exception { - ContainerControllerTester controllerTester = new ContainerControllerTester(container, responseFiles); - ContainerTester tester = controllerTester.containerTester(); + public void testApplicationApi() { tester.computeVersionStatus(); createAthenzDomainWithAdmin(ATHENZ_TENANT_DOMAIN, USER_ID); // (Necessary but not provided in this API) @@ -151,7 +163,8 @@ public class ApplicationApiTest extends ControllerContainerTest { // Add another Athens domain, so we can try to create more tenants createAthenzDomainWithAdmin(ATHENZ_TENANT_DOMAIN_2, USER_ID); // New domain to test tenant w/property ID // Add property info for that property id, as well, in the mock organization. - addPropertyData((MockOrganization) controllerTester.controller().organization(), "1234"); + registerContact(1234); + // POST (add) a tenant with property ID tester.assertResponse(request("/application/v4/tenant/tenant2", POST) .userIdentity(USER_ID) @@ -164,9 +177,10 @@ public class ApplicationApiTest extends ControllerContainerTest { .nToken(N_TOKEN) .data("{\"athensDomain\":\"domain2\", \"property\":\"property2\", \"propertyId\":\"1234\"}"), new File("tenant-without-applications-with-id.json")); - // GET a tenant with property ID + // GET a tenant with property ID and contact information + updateContactInformation(); tester.assertResponse(request("/application/v4/tenant/tenant2", GET).userIdentity(USER_ID), - new File("tenant-without-applications-with-id.json")); + new File("tenant-with-contact-info.json")); // POST (create) an application tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1", POST) @@ -316,6 +330,9 @@ public class ApplicationApiTest extends ControllerContainerTest { .recursive("true"), new File("application1-recursive.json")); + // GET logs + tester.assertResponse(request("/application/v4/tenant/tenant2/application//application1/environment/prod/region/corp-us-east-1/instance/default/logs", GET).userIdentity(USER_ID), new File("logs.json")); + // DELETE (cancel) ongoing change tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/deploying", DELETE) .userIdentity(HOSTED_VESPA_OPERATOR), @@ -465,8 +482,6 @@ public class ApplicationApiTest extends ControllerContainerTest { @Test public void testDeployDirectly() { // Setup - ContainerControllerTester controllerTester = new ContainerControllerTester(container, responseFiles); - ContainerTester tester = controllerTester.containerTester(); tester.computeVersionStatus(); createAthenzDomainWithAdmin(ATHENZ_TENANT_DOMAIN, USER_ID); @@ -500,8 +515,6 @@ public class ApplicationApiTest extends ControllerContainerTest { @Test public void testDeployDirectlyUsingOneCallForDeploy() { // Setup - ContainerControllerTester controllerTester = new ContainerControllerTester(container, responseFiles); - ContainerTester tester = controllerTester.containerTester(); tester.computeVersionStatus(); UserId userId = new UserId("new_user"); createAthenzDomainWithAdmin(ATHENZ_TENANT_DOMAIN, userId); @@ -523,10 +536,8 @@ public class ApplicationApiTest extends ControllerContainerTest { } @Test - public void testSortsDeploymentsAndJobs() throws Exception { + public void testSortsDeploymentsAndJobs() { // Setup - ContainerControllerTester controllerTester = new ContainerControllerTester(container, responseFiles); - ContainerTester tester = controllerTester.containerTester(); tester.computeVersionStatus(); createAthenzDomainWithAdmin(ATHENZ_TENANT_DOMAIN, USER_ID); @@ -602,7 +613,6 @@ public class ApplicationApiTest extends ControllerContainerTest { @Test public void testErrorResponses() throws Exception { - ContainerTester tester = new ContainerTester(container, responseFiles); tester.computeVersionStatus(); createAthenzDomainWithAdmin(ATHENZ_TENANT_DOMAIN, USER_ID); @@ -749,7 +759,7 @@ public class ApplicationApiTest extends ControllerContainerTest { // Create legancy tenant name containing underscores tester.controller().tenants().create(new AthenzTenant(TenantName.from("my_tenant"), ATHENZ_TENANT_DOMAIN, - new Property("property1"), Optional.empty()), + new Property("property1"), Optional.empty(), Optional.empty()), N_TOKEN); // POST (add) a Athenz tenant with dashes duplicates existing one with underscores tester.assertResponse(request("/application/v4/tenant/my-tenant", POST) @@ -761,8 +771,7 @@ public class ApplicationApiTest extends ControllerContainerTest { } @Test - public void testAuthorization() throws Exception { - ContainerTester tester = new ContainerTester(container, responseFiles); + public void testAuthorization() { UserId authorizedUser = USER_ID; UserId unauthorizedUser = new UserId("othertenant"); @@ -855,9 +864,7 @@ public class ApplicationApiTest extends ControllerContainerTest { } @Test - public void deployment_fails_on_illegal_domain_in_deployment_spec() throws IOException { - ContainerControllerTester controllerTester = new ContainerControllerTester(container, responseFiles); - ContainerTester tester = controllerTester.containerTester(); + public void deployment_fails_on_illegal_domain_in_deployment_spec() { ApplicationPackage applicationPackage = new ApplicationPackageBuilder() .upgradePolicy("default") .athenzIdentity(com.yahoo.config.provision.AthenzDomain.from("invalid.domain"), com.yahoo.config.provision.AthenzService.from("service")) @@ -881,8 +888,6 @@ public class ApplicationApiTest extends ControllerContainerTest { @Test public void deployment_succeeds_when_correct_domain_is_used() { - ContainerControllerTester controllerTester = new ContainerControllerTester(container, responseFiles); - ContainerTester tester = controllerTester.containerTester(); ApplicationPackage applicationPackage = new ApplicationPackageBuilder() .upgradePolicy("default") .athenzIdentity(com.yahoo.config.provision.AthenzDomain.from("domain1"), com.yahoo.config.provision.AthenzService.from("service")) @@ -912,11 +917,10 @@ public class ApplicationApiTest extends ControllerContainerTest { @Test public void testJobStatusReporting() { - ContainerControllerTester tester = new ContainerControllerTester(container, responseFiles); addUserToHostedOperatorRole(HostedAthenzIdentities.from(HOSTED_VESPA_OPERATOR)); - tester.containerTester().computeVersionStatus(); + tester.computeVersionStatus(); long projectId = 1; - Application app = tester.createApplication(); + Application app = controllerTester.createApplication(); ApplicationPackage applicationPackage = new ApplicationPackageBuilder() .environment(Environment.prod) .region("corp-us-east-1") @@ -924,11 +928,11 @@ public class ApplicationApiTest extends ControllerContainerTest { Version vespaVersion = new Version("6.1"); // system version from mock config server client - BuildJob job = new BuildJob(report -> notifyCompletion(report, tester), tester.artifactRepository()) + BuildJob job = new BuildJob(report -> notifyCompletion(report, controllerTester), controllerTester.artifactRepository()) .application(app) .projectId(projectId); job.type(JobType.component).uploadArtifact(applicationPackage).submit(); - tester.deploy(app, applicationPackage, TEST_ZONE); + controllerTester.deploy(app, applicationPackage, TEST_ZONE); job.type(JobType.systemTest).submit(); // Notifying about unknown job fails @@ -936,7 +940,7 @@ public class ApplicationApiTest extends ControllerContainerTest { .data(asJson(job.type(JobType.productionUsEast3).report())) .userIdentity(HOSTED_VESPA_OPERATOR) .get(); - tester.containerTester().assertResponse(request, new File("jobreport-unexpected-completion.json"), 400); + tester.assertResponse(request, new File("jobreport-unexpected-completion.json"), 400); // ... and assert it was recorded JobStatus recordedStatus = @@ -960,25 +964,24 @@ public class ApplicationApiTest extends ControllerContainerTest { @Test public void testJobStatusReportingOutOfCapacity() { - ContainerControllerTester tester = new ContainerControllerTester(container, responseFiles); - tester.containerTester().computeVersionStatus(); + controllerTester.containerTester().computeVersionStatus(); long projectId = 1; - Application app = tester.createApplication(); + Application app = controllerTester.createApplication(); ApplicationPackage applicationPackage = new ApplicationPackageBuilder() .environment(Environment.prod) .region("corp-us-east-1") .build(); // Report job failing with out of capacity - BuildJob job = new BuildJob(report -> notifyCompletion(report, tester), tester.artifactRepository()) + BuildJob job = new BuildJob(report -> notifyCompletion(report, controllerTester), controllerTester.artifactRepository()) .application(app) .projectId(projectId); job.type(JobType.component).uploadArtifact(applicationPackage).submit(); - tester.deploy(app, applicationPackage, TEST_ZONE); + controllerTester.deploy(app, applicationPackage, TEST_ZONE); job.type(JobType.systemTest).submit(); - tester.deploy(app, applicationPackage, STAGING_ZONE); + controllerTester.deploy(app, applicationPackage, STAGING_ZONE); job.type(JobType.stagingTest).error(DeploymentJobs.JobError.outOfCapacity).submit(); // Appropriate error is recorded @@ -1134,7 +1137,7 @@ public class ApplicationApiTest extends ControllerContainerTest { private void startAndTestChange(ContainerControllerTester controllerTester, ApplicationId application, long projectId, ApplicationPackage applicationPackage, - HttpEntity deployData, long buildNumber) throws IOException { + HttpEntity deployData, long buildNumber) { ContainerTester tester = controllerTester.containerTester(); // Trigger application change @@ -1208,11 +1211,24 @@ public class ApplicationApiTest extends ControllerContainerTest { } } - private void addPropertyData(MockOrganization organization, String propertyIdValue) { - PropertyId propertyId = new PropertyId(propertyIdValue); - organization.addProperty(propertyId); - organization.setContactsFor(propertyId, Arrays.asList(Collections.singletonList(User.from("alice")), - Collections.singletonList(User.from("bob")))); + private MockOrganization organization() { + return (MockOrganization) tester.container().components().getComponent(MockOrganization.class.getName()); + } + + private void updateContactInformation() { + new ContactInformationMaintainer(tester.controller(), Duration.ofDays(1), + new JobControl(tester.controller().curator()), + organization()).run(); + } + + private void registerContact(long propertyId) { + PropertyId p = new PropertyId(String.valueOf(propertyId)); + organization().addProperty(p) + .setIssueUrl(p, URI.create("www.issues.tld/" + p.id())) + .setContactsUrl(p, URI.create("www.contacts.tld/" + p.id())) + .setPropertyUrl(p, URI.create("www.properties.tld/" + p.id())) + .setContactsFor(p, Arrays.asList(Collections.singletonList(User.from("alice")), + Collections.singletonList(User.from("bob")))); } } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelperTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelperTest.java index 01f9ea9dfa0..4c8cf0e7784 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelperTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelperTest.java @@ -2,11 +2,7 @@ package com.yahoo.vespa.hosted.controller.restapi.application; import com.yahoo.component.Version; import com.yahoo.container.jdisc.HttpResponse; -import com.yahoo.vespa.hosted.controller.api.application.v4.model.configserverbindings.ConfigChangeActions; -import com.yahoo.vespa.hosted.controller.api.application.v4.model.configserverbindings.RefeedAction; -import com.yahoo.vespa.hosted.controller.api.identifiers.DeploymentId; import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServerException; -import com.yahoo.vespa.hosted.controller.api.integration.deployment.TesterCloud; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.ApplicationVersion; import com.yahoo.vespa.hosted.controller.deployment.InternalDeploymentTester; @@ -32,14 +28,10 @@ import static com.yahoo.vespa.hosted.controller.api.integration.deployment.JobTy import static com.yahoo.vespa.hosted.controller.api.integration.deployment.TesterCloud.Status.FAILURE; import static com.yahoo.vespa.hosted.controller.deployment.InternalDeploymentTester.appId; import static com.yahoo.vespa.hosted.controller.deployment.JobController.testerOf; -import static com.yahoo.vespa.hosted.controller.deployment.RunStatus.aborted; import static com.yahoo.vespa.hosted.controller.deployment.RunStatus.deploymentFailed; -import static com.yahoo.vespa.hosted.controller.deployment.RunStatus.error; import static com.yahoo.vespa.hosted.controller.deployment.RunStatus.installationFailed; import static com.yahoo.vespa.hosted.controller.deployment.RunStatus.running; import static com.yahoo.vespa.hosted.controller.deployment.RunStatus.testFailure; -import static java.util.Collections.emptyList; -import static java.util.Collections.singletonList; import static org.junit.Assert.assertEquals; /** @@ -109,7 +101,7 @@ public class JobControllerApiHandlerHelperTest { tester.tester().upgradeSystem(platform); // us-central-1 has started, deployed, and is installing. Deployment is not yet verified. - // us-east-3 is pending the failed staging test, while us-east-3 is pending us-central-1. + // us-east-3 is waiting for the failed staging test and us-central-1, while us-west-1 is waiting only for us-central-1. // Only us-east-3 is verified, on revision1. // staging-test has 4 runs: one success without sources on revision1, one success from revision1 to revision2, // one success from revision2 to revision3 and one failure from revision1 to revision3. diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json index ee54e2741ba..07a3dbb7f95 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json @@ -1,6 +1,7 @@ { "application": "application1", "instance": "default", + "deployments": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/instance/default/job/", "deploymentJobs": [ { "type": "component", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json index c93ff6a0dd2..0d7607f1df6 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json @@ -1,6 +1,7 @@ { "application": "application1", "instance": "default", + "deployments": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/instance/default/job/", "deploying": { "revision": { "hash": "(ignore)", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json index a1bd96e46d2..4e4a870662a 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json @@ -1,6 +1,7 @@ { "application": "application1", "instance": "default", + "deployments": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/instance/default/job/", "deploying": { "revision": { "hash": "(ignore)", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json index fa51d645cfc..837c46aaec1 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json @@ -1,6 +1,7 @@ { "application": "application2", "instance": "default", + "deployments": "http://localhost:8080/application/v4/tenant/tenant2/application/application2/instance/default/job/", "deploying": { "version": "(ignore)" }, diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/logs.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/logs.json new file mode 100644 index 00000000000..398a62758ee --- /dev/null +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/logs.json @@ -0,0 +1,5 @@ +{ + "subfolder": { + "log2.log":"VGhpcyBpcyBhbm90aGVyIGxvZyBmaWxl"}, + "log1.log":"VGhpcyBpcyBvbmUgbG9nIGZpbGU=" +}
\ No newline at end of file diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/overview.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/overview.json index b7a7dcbf796..4466872022e 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/overview.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/overview.json @@ -1,17 +1,19 @@ { "lastVersions": { "platform": { - "version": "7.1", + "platform": "7.1", "at": 0, "pending": "Waiting for current deployment to complete" }, "application": { - "version": { - "id": "1.0.3-commit1", + "application": { + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "at": 2000, "deploying": "0 of 3 complete" @@ -19,11 +21,13 @@ }, "deploying": { "application": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } } }, "deployments": [ @@ -32,11 +36,13 @@ "at": 2000, "platform": "6.1", "application": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "verified": false, "status": "verifying" @@ -47,11 +53,13 @@ "at": 1000, "platform": "6.1", "application": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "verified": false, "status": "pending" @@ -60,11 +68,13 @@ "at": 0, "platform": "6.1", "application": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "verified": true, "status": "pending" @@ -81,19 +91,23 @@ "end": 2000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -109,19 +123,23 @@ "end": 1000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -137,11 +155,13 @@ "end": 0, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -159,19 +179,23 @@ "status": "pending", "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "cooldown": "failed" @@ -185,19 +209,23 @@ "end": 2000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": {}, "log": "https://some.url:43/root/staging-test/run/4" @@ -209,19 +237,23 @@ "end": 2000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -237,19 +269,23 @@ "end": 1000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -265,11 +301,13 @@ "end": 0, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -281,7 +319,7 @@ ], "url": "https://some.url:43/root/staging-test" }, - "production-us-central-1": { + "us-central-1": { "runs": [ { "id": 3, @@ -289,19 +327,23 @@ "start": 2000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -316,19 +358,23 @@ "end": 1000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -344,11 +390,13 @@ "end": 0, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -360,28 +408,32 @@ ], "url": "https://some.url:43/root/production-us-central-1" }, - "production-us-west-1": { + "us-west-1": { "runs": [ { "status": "pending", "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { - "production-us-central-1": "running" + "us-central-1": "running" } }, { @@ -391,19 +443,23 @@ "end": 1000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -419,11 +475,13 @@ "end": 0, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -435,29 +493,33 @@ ], "url": "https://some.url:43/root/production-us-west-1" }, - "production-us-east-3": { + "us-east-3": { "runs": [ { "status": "pending", "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "staging-test": "failed", - "production-us-central-1": "running" + "us-central-1": "running" } }, { @@ -467,19 +529,23 @@ "end": 1000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "failed" @@ -493,11 +559,13 @@ "end": 0, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/staging-runs.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/staging-runs.json index 448411b3912..8c5e5253482 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/staging-runs.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/staging-runs.json @@ -6,11 +6,13 @@ "end": 0, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "steps": { "deployInitialReal": "succeeded", @@ -39,19 +41,23 @@ "end": 1000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "steps": { "deployInitialReal": "succeeded", @@ -80,19 +86,23 @@ "end": 2000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "steps": { "deployInitialReal": "succeeded", @@ -121,19 +131,23 @@ "end": 2000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "steps": { "deployInitialReal": "succeeded", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-with-contact-info.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-with-contact-info.json new file mode 100644 index 00000000000..0ba0a01c5d0 --- /dev/null +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-with-contact-info.json @@ -0,0 +1,19 @@ +{ + "tenant": "tenant2", + "type": "ATHENS", + "athensDomain": "domain2", + "property": "property2", + "propertyId": "1234", + "applications": [], + "propertyUrl": "www.properties.tld/1234", + "contactsUrl": "www.contacts.tld/1234", + "issueCreationUrl": "www.issues.tld/1234", + "contacts": [ + [ + "alice" + ], + [ + "bob" + ] + ] +} diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-without-applications-with-id.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-without-applications-with-id.json index 69949c47d8c..5624150463a 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-without-applications-with-id.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-without-applications-with-id.json @@ -4,18 +4,5 @@ "athensDomain": "domain2", "property": "property2", "propertyId": "1234", - "applications": [ - - ], - "propertyUrl": "www.properties.tld/1234", - "contactsUrl": "www.contacts.tld/1234", - "issueCreationUrl": "www.issues.tld/1234", - "contacts": [ - [ - "alice" - ], - [ - "bob" - ] - ] + "applications": [] } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json index 2b847010482..6a71e524ae4 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json @@ -10,6 +10,9 @@ "name": "ClusterUtilizationMaintainer" }, { + "name": "ContactInformationMaintainer" + }, + { "name": "DefaultOsUpgrader" }, { diff --git a/controller-server/src/test/resources/test_runner_services.xml-cd b/controller-server/src/test/resources/test_runner_services.xml-cd index 9c6cfe6fe2d..e0fca9716eb 100644 --- a/controller-server/src/test/resources/test_runner_services.xml-cd +++ b/controller-server/src/test/resources/test_runner_services.xml-cd @@ -37,6 +37,6 @@ </filtering> </http> - <nodes count="1" flavor="d-2-8-50" /> + <nodes count="1" flavor="d-1-4-50" /> </container> </services> diff --git a/docker-api/pom.xml b/docker-api/pom.xml index 64410c32f06..74e463ef157 100644 --- a/docker-api/pom.xml +++ b/docker-api/pom.xml @@ -18,16 +18,20 @@ <name>${project.artifactId}</name> <dependencies> + <!-- Provided --> <dependency> <groupId>com.yahoo.vespa</groupId> <artifactId>container-dev</artifactId> <version>${project.version}</version> <scope>provided</scope> </dependency> + + <!-- Compile --> <dependency> <groupId>com.github.docker-java</groupId> <artifactId>docker-java</artifactId> <version>3.0.13</version> + <scope>compile</scope> <exclusions> <exclusion> <groupId>org.slf4j</groupId> @@ -92,6 +96,7 @@ <dependency> <groupId>net.jpountz.lz4</groupId> <artifactId>lz4</artifactId> + <scope>compile</scope> </dependency> <dependency> <groupId>org.apache.httpcomponents</groupId> @@ -100,6 +105,7 @@ docker-java so the dependency is declared closer to the root of maven and more likely be the version that is finally being used. --> <version>4.4.1</version> + <scope>compile</scope> </dependency> <dependency> <groupId>org.apache.httpcomponents</groupId> @@ -108,7 +114,10 @@ docker-java so the dependency is declared closer to the root of maven and more likely be the version that is finally being used. --> <version>4.5</version> + <scope>compile</scope> </dependency> + + <!-- Test --> <dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> @@ -121,7 +130,6 @@ </dependency> </dependencies> - <build> <plugins> <plugin> diff --git a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/ContainerStatsImpl.java b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/ContainerStatsImpl.java index f0419a36d46..a56c1e41a51 100644 --- a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/ContainerStatsImpl.java +++ b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/ContainerStatsImpl.java @@ -1,12 +1,14 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.dockerapi; +import java.util.Collections; import java.util.Map; +import java.util.Optional; /** * Wrapper class for {@link com.github.dockerjava.api.model.Statistics} to prevent leaking from docker-java library. * - * @author valerijf + * @author freva */ public class ContainerStatsImpl implements Docker.ContainerStats { private final Map<String, Object> networks; @@ -16,7 +18,8 @@ public class ContainerStatsImpl implements Docker.ContainerStats { public ContainerStatsImpl(Map<String, Object> networks, Map<String, Object> cpuStats, Map<String, Object> memoryStats, Map<String, Object> blkioStats) { - this.networks = networks; + // Network stats are null when container uses host network + this.networks = Optional.ofNullable(networks).orElse(Collections.emptyMap()); this.cpuStats = cpuStats; this.memoryStats = memoryStats; this.blkioStats = blkioStats; diff --git a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImpl.java b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImpl.java index 260e2da7c59..d95f7b7b8e1 100644 --- a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImpl.java +++ b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImpl.java @@ -107,7 +107,14 @@ class CreateContainerCommandImpl implements Docker.CreateContainerCommand { @Override public Docker.CreateContainerCommand withVolume(String path, String volumePath) { assert path.indexOf(':') == -1; - volumeBindSpecs.add(path + ":" + volumePath); + volumeBindSpecs.add(path + ":" + volumePath + ":Z"); + return this; + } + + @Override + public Docker.CreateContainerCommand withSharedVolume(String path, String volumePath) { + assert path.indexOf(':') == -1; + volumeBindSpecs.add(path + ":" + volumePath + ":z"); return this; } diff --git a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/Docker.java b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/Docker.java index 91d5125eba3..5e8a0feb099 100644 --- a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/Docker.java +++ b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/Docker.java @@ -19,7 +19,30 @@ public interface Docker { interface CreateContainerCommand { CreateContainerCommand withLabel(String name, String value); CreateContainerCommand withEnvironment(String name, String value); + + /** + * Mounts a directory on host inside the docker container. + * + * <p>Bind mount content will be <b>private</b> to this container (and host) only. + * + * <p>When using this method and selinux is enabled (/usr/sbin/sestatus), starting + * multiple containers which mount host's /foo directory into the container, will make + * /foo's content visible/readable/writable only inside the container which was last + * started and on the host. All the other containers will get "Permission denied". + * + * <p>Use {@link #withSharedVolume(String, String)} to mount a given host directory + * into multiple containers. + */ CreateContainerCommand withVolume(String path, String volumePath); + + /** + * Mounts a directory on host inside the docker container. + * + * <p>The bind mount content will be <b>shared</b> among multiple containers. + * + * @see #withVolume(String, String) + */ + CreateContainerCommand withSharedVolume(String path, String volumePath); CreateContainerCommand withNetworkMode(String mode); CreateContainerCommand withIpAddress(InetAddress address); CreateContainerCommand withUlimit(String name, int softLimit, int hardLimit); diff --git a/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImplTest.java b/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImplTest.java index 0d8701ac43c..5ce8c6b093c 100644 --- a/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImplTest.java +++ b/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImplTest.java @@ -46,7 +46,7 @@ public class CreateContainerCommandImplTest { "--ulimit nproc=10:20 " + "--env env1=val1 " + "--env env2=val2 " + - "--volume vol1:/host/vol1 " + + "--volume vol1:/host/vol1:Z " + "--cap-add SYS_ADMIN " + "--cap-add SYS_PTRACE " + "--cap-drop NET_ADMIN " + diff --git a/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java b/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java index 80f1f003412..0f3f3938701 100644 --- a/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java +++ b/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java @@ -34,7 +34,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.logging.Logger; /** - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> + * @author Einar M R Rosenvinge */ public class MbusRequestContext implements RequestContext, ResponseHandler { diff --git a/document/src/main/java/com/yahoo/document/DataType.java b/document/src/main/java/com/yahoo/document/DataType.java index abdbf394591..3f34314f0de 100644 --- a/document/src/main/java/com/yahoo/document/DataType.java +++ b/document/src/main/java/com/yahoo/document/DataType.java @@ -246,9 +246,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com } public boolean equals(Object other) { - if (!(other instanceof DataType)) return false; - DataType type = (DataType)other; - return (name.equals(type.name) && dataTypeId == type.dataTypeId); + return (other instanceof DataType) && (dataTypeId == ((DataType)other).dataTypeId); } public String toString() { diff --git a/document/src/main/java/com/yahoo/document/DocumentTypeManagerConfigurer.java b/document/src/main/java/com/yahoo/document/DocumentTypeManagerConfigurer.java index 4ae5e6a713c..7678360ea30 100644 --- a/document/src/main/java/com/yahoo/document/DocumentTypeManagerConfigurer.java +++ b/document/src/main/java/com/yahoo/document/DocumentTypeManagerConfigurer.java @@ -70,9 +70,7 @@ public class DocumentTypeManagerConfigurer implements ConfigSubscriber.SingleSub log.log(LogLevel.DEBUG, "Configuring document manager with " + config.datatype().size() + " data types."); ArrayList<DocumentmanagerConfig.Datatype> failed = new ArrayList<>(); failed.addAll(config.datatype()); - int failCounter = 30; while (!failed.isEmpty()) { - --failCounter; ArrayList<DocumentmanagerConfig.Datatype> tmp = failed; failed = new ArrayList<>(); for (int i = 0; i < tmp.size(); i++) { @@ -82,9 +80,6 @@ public class DocumentTypeManagerConfigurer implements ConfigSubscriber.SingleSub registerTypeIdMapping(config, manager, thisDataType, id); } catch (IllegalArgumentException e) { failed.add(thisDataType); - if (failCounter < 0) { - throw e; - } } } } diff --git a/document/src/main/java/com/yahoo/document/DocumentUpdate.java b/document/src/main/java/com/yahoo/document/DocumentUpdate.java index 70c5410534e..ad93942c1c0 100644 --- a/document/src/main/java/com/yahoo/document/DocumentUpdate.java +++ b/document/src/main/java/com/yahoo/document/DocumentUpdate.java @@ -7,6 +7,7 @@ import com.yahoo.document.serialization.DocumentSerializerFactory; import com.yahoo.document.serialization.DocumentUpdateReader; import com.yahoo.document.serialization.DocumentUpdateWriter; import com.yahoo.document.update.AssignValueUpdate; +import com.yahoo.document.update.ClearValueUpdate; import com.yahoo.document.update.FieldUpdate; import com.yahoo.document.update.ValueUpdate; import com.yahoo.io.GrowableByteBuffer; @@ -137,9 +138,20 @@ public class DocumentUpdate extends DocumentOperation implements Iterable<FieldP ValueUpdate last = update.getValueUpdate(update.size() - 1); if (last instanceof AssignValueUpdate) { FieldValue currentValue = doc.getFieldValue(update.getField()); - if ((currentValue != null) && (currentValue.compareTo(last.getValue()) == 0)) { + if ((currentValue != null) && currentValue.equals(last.getValue())) { iter.remove(); } + } else if (last instanceof ClearValueUpdate) { + FieldValue currentValue = doc.getFieldValue(update.getField()); + if (currentValue == null) { + iter.remove(); + } else { + FieldValue copy = currentValue.clone(); + copy.clear(); + if (currentValue.equals(copy)) { + iter.remove(); + } + } } } } diff --git a/document/src/main/java/com/yahoo/document/ReferenceDataType.java b/document/src/main/java/com/yahoo/document/ReferenceDataType.java index 5b5ba256f43..115917c4118 100644 --- a/document/src/main/java/com/yahoo/document/ReferenceDataType.java +++ b/document/src/main/java/com/yahoo/document/ReferenceDataType.java @@ -75,6 +75,7 @@ public class ReferenceDataType extends DataType { "type in ReferenceDataType instance (type is '%s')", this.targetType.getName())); } this.targetType = targetType; + setName(buildTypeName(targetType)); } @Override @@ -98,4 +99,21 @@ public class ReferenceDataType extends DataType { ReferenceFieldValue rhs = (ReferenceFieldValue)value; return rhs.getDataType().equals(this); } + + private int compareTargetType(DataType rhs) { + return (rhs instanceof ReferenceDataType) ? targetType.compareTo(((ReferenceDataType) rhs).targetType) : 0; + } + + @Override + public int compareTo(DataType rhs) { + int cmp = super.compareTo(rhs); + return (cmp != 0) ? cmp : compareTargetType(rhs); + } + + @Override + public boolean equals(Object rhs) { + return super.equals(rhs) + && (rhs instanceof ReferenceDataType) + && targetType.equals(((ReferenceDataType) rhs).targetType); + } } diff --git a/document/src/main/java/com/yahoo/document/datatypes/Array.java b/document/src/main/java/com/yahoo/document/datatypes/Array.java index 8f6b68fcc38..660e58efa25 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/Array.java +++ b/document/src/main/java/com/yahoo/document/datatypes/Array.java @@ -11,7 +11,14 @@ import com.yahoo.document.serialization.FieldWriter; import com.yahoo.document.serialization.XmlSerializationHelper; import com.yahoo.document.serialization.XmlStream; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.ListIterator; +import java.util.RandomAccess; /** * FieldValue which encapsulates a Array value diff --git a/document/src/main/java/com/yahoo/document/datatypes/MapFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/MapFieldValue.java index 6d6c18755c1..4777622be1f 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/MapFieldValue.java +++ b/document/src/main/java/com/yahoo/document/datatypes/MapFieldValue.java @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.document.datatypes; -import com.yahoo.collections.CollectionComparator; import com.yahoo.document.DataType; import com.yahoo.document.Field; import com.yahoo.document.FieldPath; @@ -10,7 +9,17 @@ import com.yahoo.document.serialization.FieldReader; import com.yahoo.document.serialization.FieldWriter; import com.yahoo.document.serialization.XmlSerializationHelper; import com.yahoo.document.serialization.XmlStream; -import java.util.*; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.Map; +import java.util.HashMap; +import java.util.Set; +import java.util.List; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; + /** * Vespa map. Backed by and and parametrized by FieldValue @@ -72,10 +81,7 @@ public class MapFieldValue<K extends FieldValue, V extends FieldValue> extends C */ public boolean equals(Object o) { if (!(o instanceof MapFieldValue)) return false; - MapFieldValue otherSet = (MapFieldValue) o; - Map<K, V> map1 = values; - Map<K, V> map2 = otherSet.values; - return (super.equals(o) && map1.equals(map2)); + return super.equals(o) && entrySet().equals(((MapFieldValue) o).entrySet()); } @Override @@ -276,14 +282,24 @@ public class MapFieldValue<K extends FieldValue, V extends FieldValue> extends C return comp; } //types are equal, this must be of this type - MapFieldValue otherValue = (MapFieldValue) fieldValue; - comp = CollectionComparator.compare(values.keySet(), otherValue.values.keySet()); - - if (comp != 0) { - return comp; + MapFieldValue<K,V> rhs = (MapFieldValue<K,V>) fieldValue; + if (size() < rhs.size()) { + return -1; + } else if (size() > rhs.size()) { + return 1; + } + Map.Entry<K,V> [] entries = entrySet().toArray(new Map.Entry[size()]); + Map.Entry<K,V> [] rhsEntries = rhs.entrySet().toArray(new Map.Entry[rhs.size()]); + Arrays.sort(entries, Comparator.comparing(Map.Entry<K,V>::getKey)); + Arrays.sort(rhsEntries, Comparator.comparing(Map.Entry<K,V>::getKey)); + for (int i = 0; i < entries.length; i++) { + comp = entries[i].getKey().compareTo(rhsEntries[i].getKey()); + if (comp != 0) return comp; + comp = entries[i].getValue().compareTo(rhsEntries[i].getValue()); + if (comp != 0) return comp; } - return CollectionComparator.compare(values.values(), otherValue.values.values()); + return 0; } /** diff --git a/document/src/main/java/com/yahoo/document/datatypes/ReferenceFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/ReferenceFieldValue.java index eb8ce4d3b24..0097e5c93d0 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/ReferenceFieldValue.java +++ b/document/src/main/java/com/yahoo/document/datatypes/ReferenceFieldValue.java @@ -162,4 +162,9 @@ public class ReferenceFieldValue extends FieldValue { public DocumentId getWrappedValue() { return documentId.orElse(null); } + + @Override + public String toString() { + return documentId.toString(); + } } diff --git a/document/src/main/java/com/yahoo/document/datatypes/WeightedSet.java b/document/src/main/java/com/yahoo/document/datatypes/WeightedSet.java index 0e4c56406f0..d505380523f 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/WeightedSet.java +++ b/document/src/main/java/com/yahoo/document/datatypes/WeightedSet.java @@ -1,14 +1,25 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.document.datatypes; -import com.yahoo.collections.CollectionComparator; -import com.yahoo.document.*; +import com.yahoo.document.DataType; +import com.yahoo.document.Field; +import com.yahoo.document.WeightedSetDataType; +import com.yahoo.document.MapDataType; +import com.yahoo.document.FieldPath; import com.yahoo.document.serialization.FieldReader; import com.yahoo.document.serialization.FieldWriter; import com.yahoo.document.serialization.XmlSerializationHelper; import com.yahoo.document.serialization.XmlStream; -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; + /** * A weighted set, a unique set of keys with an associated integer weight. This class @@ -241,8 +252,7 @@ public final class WeightedSet<K extends FieldValue> extends CollectionFieldValu */ public boolean equals(Object o) { if (!(o instanceof WeightedSet)) return false; - WeightedSet otherSet = (WeightedSet) o; - return (super.equals(o) && map.equals(otherSet.map)); + return (super.equals(o) && map.equals(((WeightedSet<K>)o).map)); } /** @@ -293,15 +303,7 @@ public final class WeightedSet<K extends FieldValue> extends CollectionFieldValu return comp; } - //types are equal, this must be of this type - WeightedSet otherValue = (WeightedSet) fieldValue; - comp = CollectionComparator.compare(map.keySet(), otherValue.map.keySet()); - - if (comp != 0) { - return comp; - } - - return CollectionComparator.compare(map.values(), otherValue.map.values()); + return map.compareTo(((WeightedSet<K>)fieldValue).map); } diff --git a/document/src/test/document/documentmanager.replaced_temporary.cfg b/document/src/test/document/documentmanager.replaced_temporary.cfg new file mode 100644 index 00000000000..06a5364cfb1 --- /dev/null +++ b/document/src/test/document/documentmanager.replaced_temporary.cfg @@ -0,0 +1,55 @@ +enablecompression false +datatype[12].id 1191756824 +datatype[12].referencetype[0].target_type_id 886149367 +datatype[13].id -2054976470 +datatype[13].arraytype[0].datatype 5 +datatype[14].id 959075962 +datatype[14].structtype[0].name "ad.header" +datatype[14].structtype[0].version 0 +datatype[14].structtype[0].compresstype "NONE" +datatype[14].structtype[0].compresslevel 0 +datatype[14].structtype[0].compressthreshold 95 +datatype[14].structtype[0].compressminsize 800 +datatype[14].structtype[0].field[286].datatype 1191756824 +datatype[14].structtype[0].field[286].name "campaign_ref" +datatype[14].structtype[0].field[286].detailedtype "" +datatype[15].id -255288561 +datatype[15].structtype[0].name "ad.body" +datatype[15].structtype[0].version 0 +datatype[15].structtype[0].compresstype "NONE" +datatype[15].structtype[0].compresslevel 0 +datatype[15].structtype[0].compressthreshold 95 +datatype[15].structtype[0].compressminsize 800 +datatype[16].id 2987301 +datatype[16].documenttype[0].name "ad" +datatype[16].documenttype[0].version 0 +datatype[16].documenttype[0].inherits[0].name "document" +datatype[16].documenttype[0].inherits[0].version 0 +datatype[16].documenttype[0].headerstruct 959075962 +datatype[16].documenttype[0].bodystruct -255288561 +datatype[16].documenttype[0].fieldsets.[document].fields[722] "campaign_ref" +datatype[57].id 350014056 +datatype[57].structtype[0].name "mystiqueCampaign.header" +datatype[57].structtype[0].version 0 +datatype[57].structtype[0].compresstype "NONE" +datatype[57].structtype[0].compresslevel 0 +datatype[57].structtype[0].compressthreshold 95 +datatype[57].structtype[0].compressminsize 800 +datatype[57].structtype[0].field[0].datatype 4 +datatype[57].structtype[0].field[0].name "campaign_id" +datatype[57].structtype[0].field[0].detailedtype "" +datatype[58].id -524078467 +datatype[58].structtype[0].name "mystiqueCampaign.body" +datatype[58].structtype[0].version 0 +datatype[58].structtype[0].compresstype "NONE" +datatype[58].structtype[0].compresslevel 0 +datatype[58].structtype[0].compressthreshold 95 +datatype[58].structtype[0].compressminsize 800 +datatype[59].id 886149367 +datatype[59].documenttype[0].name "mystiqueCampaign" +datatype[59].documenttype[0].version 0 +datatype[59].documenttype[0].inherits[0].name "document" +datatype[59].documenttype[0].inherits[0].version 0 +datatype[59].documenttype[0].headerstruct 350014056 +datatype[59].documenttype[0].bodystruct -524078467 +datatype[59].documenttype[0].fieldsets.[document].fields[0] "campaign_id" diff --git a/document/src/test/java/com/yahoo/document/DocumentTypeManagerTestCase.java b/document/src/test/java/com/yahoo/document/DocumentTypeManagerTestCase.java index aa4f5211df7..6acae4f37c6 100644 --- a/document/src/test/java/com/yahoo/document/DocumentTypeManagerTestCase.java +++ b/document/src/test/java/com/yahoo/document/DocumentTypeManagerTestCase.java @@ -528,6 +528,16 @@ search annotationsimplicitstruct { } @Test + public void no_temporary_targets_in_references_or_names() { + DocumentTypeManager manager = createConfiguredManager("file:src/test/document/documentmanager.replaced_temporary.cfg"); + DocumentType docType = manager.getDocumentType("ad"); + Field f = docType.getField("campaign_ref"); + assertTrue(f.getDataType() instanceof ReferenceDataType); + assertFalse(((ReferenceDataType)f.getDataType()).getTargetType() instanceof TemporaryStructuredDataType); + assertEquals("Reference<mystiqueCampaign>", f.getDataType().getName()); + } + + @Test public void can_have_reference_type_pointing_to_own_document_type() { DocumentTypeManager manager = createConfiguredManager("file:src/test/document/documentmanager.selfreference.cfg"); diff --git a/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java b/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java index 4f3d7d3b820..15319985591 100644 --- a/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java +++ b/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java @@ -709,6 +709,48 @@ public class DocumentUpdateTestCase { assertEquals(expected, doc.getFieldValue(field).getWrappedValue()); } + @Test + public void testThatClearCanBePrunedIfNoneExisting() { + Field field = docType.getField("strfoo"); + Document doc = createDocument(); + StringFieldValue expected = new StringFieldValue("some value"); + expected.clear(); + doc.setFieldValue(field, expected); + DocumentUpdate update = new DocumentUpdate(docType, new DocumentId(documentId)); + update.addFieldUpdate(FieldUpdate.createClearField(field)); + update.prune(doc); + assertEquals(0, update.size()); + update.applyTo(doc); + assertEquals(expected, doc.getFieldValue(field)); + } + + @Test + public void testThatClearCanBePrunedIfEmpty() { + Field field = docType.getField("strfoo"); + String expected = ""; + Document doc = createDocument(); + DocumentUpdate update = new DocumentUpdate(docType, new DocumentId(documentId)); + update.addFieldUpdate(FieldUpdate.createClearField(field)); + update.prune(doc); + assertEquals(0, update.size()); + update.applyTo(doc); + assertNull(doc.getFieldValue(field)); + } + + @Test + public void testThatClearCanBePrunedIfNoneExistingAndLast() { + Field field = docType.getField("strfoo"); + String expected = ""; + Document doc = createDocument(); + DocumentUpdate update = new DocumentUpdate(docType, new DocumentId(documentId)); + update.addFieldUpdate(FieldUpdate.createAssign(field, new StringFieldValue("some value"))); + update.addFieldUpdate(FieldUpdate.createClearField(field)); + update.prune(doc); + assertEquals(0, update.size()); + update.applyTo(doc); + assertNull(doc.getFieldValue(field)); + } + private static TensorFieldValue createTensorFieldValue(String tensor) { return new TensorFieldValue(Tensor.from(tensor)); } diff --git a/document/src/test/java/com/yahoo/document/datatypes/ArrayTestCase.java b/document/src/test/java/com/yahoo/document/datatypes/ArrayTestCase.java index d1c7ce8dcf6..ce250c67658 100755 --- a/document/src/test/java/com/yahoo/document/datatypes/ArrayTestCase.java +++ b/document/src/test/java/com/yahoo/document/datatypes/ArrayTestCase.java @@ -5,9 +5,14 @@ import com.yahoo.document.ArrayDataType; import com.yahoo.document.DataType; import org.junit.Test; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.ListIterator; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; @@ -269,4 +274,48 @@ public class ArrayTestCase { assertTrue(Arrays.equals(expected, array.toArray(new StringFieldValue[0]))); } + @Test + public void testEquals() { + Array<StringFieldValue> a = new Array<>(new ArrayDataType(DataType.STRING)); + a.add(new StringFieldValue("mumbo jumbo 1")); + a.add(new StringFieldValue("mumbo jumbo 2")); + Array<StringFieldValue> b = new Array<>(new ArrayDataType(DataType.STRING)); + b.add(new StringFieldValue("mumbo jumbo 1")); + b.add(new StringFieldValue("mumbo jumbo 2")); + assertEquals(a, b); + assertEquals(0, a.compareTo(b)); + assertEquals(0, b.compareTo(a)); + + b.clear(); + List<String> l = new ArrayList<>(); + l.add("mumbo jumbo 1"); + l.add("mumbo jumbo 2"); + b.assign(l); + assertEquals(a, b); + assertEquals(0, a.compareTo(b)); + assertEquals(0, b.compareTo(a)); + } + + @Test + public void testLess() { + Array<StringFieldValue> a = new Array<>(new ArrayDataType(DataType.STRING)); + a.add(new StringFieldValue("mumbo jumbo 1")); + a.add(new StringFieldValue("mumbo jumbo 3")); + Array<StringFieldValue> b = new Array<>(new ArrayDataType(DataType.STRING)); + b.add(new StringFieldValue("mumbo jumbo 1")); + b.add(new StringFieldValue("mumbo jumbo 2")); + assertNotEquals(a, b); + assertEquals(1, a.compareTo(b)); + assertEquals(-1, b.compareTo(a)); + + b.clear(); + List<String> l = new ArrayList<>(); + l.add("mumbo jumbo 1"); + l.add("mumbo jumbo 2"); + b.assign(l); + assertNotEquals(a, b); + assertEquals(1, a.compareTo(b)); + assertEquals(-1, b.compareTo(a)); + } + } diff --git a/document/src/test/java/com/yahoo/document/datatypes/ReferenceFieldValueTestCase.java b/document/src/test/java/com/yahoo/document/datatypes/ReferenceFieldValueTestCase.java index 2615ed6d442..f931c6682b7 100644 --- a/document/src/test/java/com/yahoo/document/datatypes/ReferenceFieldValueTestCase.java +++ b/document/src/test/java/com/yahoo/document/datatypes/ReferenceFieldValueTestCase.java @@ -159,6 +159,12 @@ public class ReferenceFieldValueTestCase { ReferenceFieldValue rhs = new ReferenceFieldValue(referenceTypeFoo(), docId("id:ns:foo::toad")); assertEquals(lhs, rhs); } + @Test + public void references_with_same_type_and_no_id_are_equal() { + ReferenceFieldValue lhs = new ReferenceFieldValue(referenceTypeFoo()); + ReferenceFieldValue rhs = new ReferenceFieldValue(referenceTypeFoo()); + assertEquals(lhs, rhs); + } @Test public void hash_code_takes_type_and_id_into_account() { @@ -202,4 +208,10 @@ public class ReferenceFieldValueTestCase { assertEquals(docId("id:ns:foo::toad"), idRef.getWrappedValue()); } + @Test + public void that_toString_provides_value() { + assertEquals("Optional.empty", new ReferenceFieldValue(referenceTypeFoo()).toString()); + assertEquals("Optional[id:ns:foo::toad]", new ReferenceFieldValue(referenceTypeFoo(), docId("id:ns:foo::toad")).toString()); + } + } diff --git a/document/src/test/java/com/yahoo/document/datatypes/WeightedSetTestCase.java b/document/src/test/java/com/yahoo/document/datatypes/WeightedSetTestCase.java index 3436c73feae..2c6e208f888 100644 --- a/document/src/test/java/com/yahoo/document/datatypes/WeightedSetTestCase.java +++ b/document/src/test/java/com/yahoo/document/datatypes/WeightedSetTestCase.java @@ -2,13 +2,14 @@ package com.yahoo.document.datatypes; import com.yahoo.document.DataType; -import com.yahoo.document.MapDataType; import org.junit.Test; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; @@ -21,6 +22,76 @@ import static org.junit.Assert.fail; public class WeightedSetTestCase { @Test + public void testEquals() { + WeightedSet<StringFieldValue> a = new WeightedSet<>(DataType.TAG); + a.put(new StringFieldValue("this is a test"), 5); + a.put(new StringFieldValue("this is a second test"), 7); + + WeightedSet<StringFieldValue> b = new WeightedSet<>(DataType.TAG); + b.put(new StringFieldValue("this is a second test"), 7); + b.put(new StringFieldValue("this is a test"), 5); + + assertEquals(a, b); + assertEquals(0, a.compareTo(b)); + assertEquals(0, b.compareTo(a)); + } + + @Test + public void testEqualsOnMixedPrimitiveAndFieldValues() { + WeightedSet<StringFieldValue> a = new WeightedSet<>(DataType.TAG); + a.put(new StringFieldValue("this is a test"), 5); + a.put(new StringFieldValue("this is a second test"), 7); + + WeightedSet<StringFieldValue> b = new WeightedSet<>(DataType.TAG); + Map<String, Integer> m = new HashMap<String, Integer>(); + m.put("this is a second test", 7); + m.put("this is a test", 5); + b.assign(m); + + assertEquals(a, b); + assertEquals(0, a.compareTo(b)); + assertEquals(0, b.compareTo(a)); + } + + @Test + public void testCompareTo() { + WeightedSet<StringFieldValue> a = new WeightedSet<>(DataType.TAG); + a.put(new StringFieldValue("this is a test"), 5); + a.put(new StringFieldValue("this is a second test"), 7); + + WeightedSet<StringFieldValue> b = new WeightedSet<>(DataType.TAG); + b.put(new StringFieldValue("this is a test"), 5); + + assertNotEquals(a, b); + assertEquals(1, a.compareTo(b)); + assertEquals(-1, b.compareTo(a)); + + b.clear(); + b.put(new StringFieldValue("this is a test"), 5); + b.put(new StringFieldValue("this is a third test"), 7); + + assertNotEquals(a, b); + assertEquals(-1, a.compareTo(b)); + assertEquals(1, b.compareTo(a)); + + b.clear(); + b.put(new StringFieldValue("this is a test"), 5); + b.put(new StringFieldValue("this is a second test"), 7); + + assertEquals(a, b); + assertEquals(0, a.compareTo(b)); + assertEquals(0, b.compareTo(a)); + + b.clear(); + b.put(new StringFieldValue("this is a test"), 5); + b.put(new StringFieldValue("this is a second test"), 6); + + assertNotEquals(a, b); + assertEquals(1, a.compareTo(b)); + assertEquals(-1, b.compareTo(a)); + } + + @Test public void testSet() { WeightedSet<StringFieldValue> wset = new WeightedSet<>(DataType.TAG); diff --git a/document/src/vespa/document/datatype/datatype.cpp b/document/src/vespa/document/datatype/datatype.cpp index aef155999a4..8d2721a4d9b 100644 --- a/document/src/vespa/document/datatype/datatype.cpp +++ b/document/src/vespa/document/datatype/datatype.cpp @@ -159,7 +159,7 @@ DataType::~DataType() = default; bool DataType::operator==(const DataType& other) const { - return _dataTypeId == other._dataTypeId && _name == other._name; + return _dataTypeId == other._dataTypeId; } bool diff --git a/document/src/vespa/document/datatype/referencedatatype.cpp b/document/src/vespa/document/datatype/referencedatatype.cpp index 6792d95909c..7b7c83c7fa6 100644 --- a/document/src/vespa/document/datatype/referencedatatype.cpp +++ b/document/src/vespa/document/datatype/referencedatatype.cpp @@ -41,4 +41,10 @@ void ReferenceDataType::onBuildFieldPath(FieldPath &, vespalib::stringref remain } +bool ReferenceDataType::operator==(const DataType &rhs) const { + return DataType::operator==(rhs) + && rhs.inherits(classId) + && (_targetDocType == static_cast<const ReferenceDataType &>(rhs)._targetDocType); +} + } // document diff --git a/document/src/vespa/document/datatype/referencedatatype.h b/document/src/vespa/document/datatype/referencedatatype.h index d5804d09835..5ca52f3ccb2 100644 --- a/document/src/vespa/document/datatype/referencedatatype.h +++ b/document/src/vespa/document/datatype/referencedatatype.h @@ -24,6 +24,8 @@ public: void print(std::ostream&, bool verbose, const std::string& indent) const override; ReferenceDataType* clone() const override; void onBuildFieldPath(FieldPath & path, vespalib::stringref remainingFieldName) const override; + + bool operator==(const DataType &type) const override; }; } // document diff --git a/document/src/vespa/document/datatype/structdatatype.cpp b/document/src/vespa/document/datatype/structdatatype.cpp index 3ccb08c32be..7c308202e3b 100644 --- a/document/src/vespa/document/datatype/structdatatype.cpp +++ b/document/src/vespa/document/datatype/structdatatype.cpp @@ -40,7 +40,7 @@ StructDataType::StructDataType(vespalib::stringref name, int32_t dataTypeId) _compressionConfig() { } -StructDataType::~StructDataType() { } +StructDataType::~StructDataType() = default; StructDataType* StructDataType::clone() const { diff --git a/document/src/vespa/document/datatype/structdatatype.h b/document/src/vespa/document/datatype/structdatatype.h index 4491ed68e01..42003d3b466 100644 --- a/document/src/vespa/document/datatype/structdatatype.h +++ b/document/src/vespa/document/datatype/structdatatype.h @@ -71,10 +71,10 @@ public: DECLARE_IDENTIFIABLE(StructDataType); private: - typedef vespalib::hash_map<vespalib::string, Field::SP> StringFieldMap; - typedef vespalib::hash_map<int32_t, Field::SP> IntFieldMap; - StringFieldMap _nameFieldMap; - IntFieldMap _idFieldMap; + using StringFieldMap = vespalib::hash_map<vespalib::string, Field::SP>; + using IntFieldMap = vespalib::hash_map<int32_t, Field::SP>; + StringFieldMap _nameFieldMap; + IntFieldMap _idFieldMap; CompressionConfig _compressionConfig; /** @return "" if not conflicting. Error message otherwise. */ diff --git a/document/src/vespa/document/repo/configbuilder.cpp b/document/src/vespa/document/repo/configbuilder.cpp index 45433c2a606..42b37104e04 100644 --- a/document/src/vespa/document/repo/configbuilder.cpp +++ b/document/src/vespa/document/repo/configbuilder.cpp @@ -2,8 +2,8 @@ #include "configbuilder.h" -namespace document { -namespace config_builder { +namespace document::config_builder { + int32_t createFieldId(const vespalib::string &name, int32_t type) { StructDataType dummy("dummy", type); Field f(name, dummy, true); @@ -63,5 +63,4 @@ DocumenttypesConfigBuilderHelper::document(int32_t id, const vespalib::string &n return DocTypeRep(_config.documenttype.back()); } -} // namespace config_builder -} // namespace document +} diff --git a/document/src/vespa/document/repo/configbuilder.h b/document/src/vespa/document/repo/configbuilder.h index 598c72f6358..c389fd3b09e 100644 --- a/document/src/vespa/document/repo/configbuilder.h +++ b/document/src/vespa/document/repo/configbuilder.h @@ -9,8 +9,7 @@ #include <vespa/vespalib/stllike/string.h> #include <cassert> -namespace document { -namespace config_builder { +namespace document::config_builder { class TypeOrId; @@ -143,6 +142,6 @@ public: ::document::DocumenttypesConfigBuilder &config() { return _config; } }; -} // namespace config_builder -} // namespace document + +} diff --git a/document/src/vespa/document/repo/documenttyperepo.cpp b/document/src/vespa/document/repo/documenttyperepo.cpp index 03b7660efbe..a320750e0d5 100644 --- a/document/src/vespa/document/repo/documenttyperepo.cpp +++ b/document/src/vespa/document/repo/documenttyperepo.cpp @@ -89,7 +89,7 @@ void Repo::inherit(const Repo &parent) { bool Repo::addDataType(const DataType &type) { const DataType *& data_type = _types[type.getId()]; if (data_type) { - if (*data_type == type) { + if ((*data_type == type) && (data_type->getName() == type.getName())) { return false; // Redefinition of identical type is ok. } throw IllegalArgumentException( diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusAsyncSession.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusAsyncSession.java index 42753338d06..eb71b3cfe47 100644 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusAsyncSession.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusAsyncSession.java @@ -26,7 +26,7 @@ import java.util.logging.Logger; * The sessions are multithread safe. * * @author bratseth - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar Rosenvinge</a> + * @author Einar Rosenvinge */ public class MessageBusAsyncSession implements MessageBusSession, AsyncSession { diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java index 7423792693b..dbf68106e07 100755 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java @@ -20,7 +20,7 @@ import java.util.List; public class ANDPolicy implements DocumentProtocolRoutingPolicy { // A list of hops that are to always be selected when select() is invoked. - private final List<Hop> hops = new ArrayList<Hop>(); + private final List<Hop> hops = new ArrayList<>(); /** * Constructs a new AND policy that requires all recipients to be ok for it to merge their replies to an ok reply. diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java index 82679e17990..a5b3accac68 100644 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java @@ -51,7 +51,7 @@ public class MessageTypePolicy implements DocumentProtocolRoutingPolicy, ConfigS @Override public void configure(MessagetyperouteselectorpolicyConfig cfg) { - Map<Integer, Route> h = new HashMap<Integer, Route>(); + Map<Integer, Route> h = new HashMap<>(); for (MessagetyperouteselectorpolicyConfig.Route selector : cfg.route()) { h.put(selector.messagetype(), Route.parse(selector.name())); } diff --git a/documentgen-test/etc/complex/music.sd b/documentgen-test/etc/complex/music.sd index 6dbb7b76862..b95edd2b4f3 100644 --- a/documentgen-test/etc/complex/music.sd +++ b/documentgen-test/etc/complex/music.sd @@ -46,12 +46,12 @@ search music { field sw1 type float { indexing { - input weight * 6 + input w1 + input w2 | summary; + input weight_src * 6 + input w1_src + input w2_src | summary; } } field didinteger type array<int> { - indexing: input did | split " " | attribute + indexing: input did | split " " | for_each { to_int } | attribute } rank-profile default { diff --git a/documentgen-test/etc/complex/music2.sd b/documentgen-test/etc/complex/music2.sd index e608225bf38..4cc9db0651e 100644 --- a/documentgen-test/etc/complex/music2.sd +++ b/documentgen-test/etc/complex/music2.sd @@ -51,12 +51,12 @@ search music2 { field sw1 type float { indexing { - input weight * 6 + input w1 + input w2 | summary; + input weight_src * 6 + input w1_src + input w2_src | summary; } } field didinteger type array<int> { - indexing: input did | split " " | attribute + indexing: input did | split " " | for_each { to_int } | attribute } rank-profile default { diff --git a/documentgen-test/etc/complex/video.sd b/documentgen-test/etc/complex/video.sd index fc7f58298c1..0749daa01aa 100644 --- a/documentgen-test/etc/complex/video.sd +++ b/documentgen-test/etc/complex/video.sd @@ -32,7 +32,7 @@ search video { field sw1 type float { indexing { - input weight * 6 + input w1 + input w2 | summary; + input weight_src * 6 + input w1_src + input w2_src | summary; } } diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml index 0421f0bbd01..3dc6e1eee79 100644 --- a/fat-model-dependencies/pom.xml +++ b/fat-model-dependencies/pom.xml @@ -75,6 +75,11 @@ </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> + <artifactId>model-evaluation</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> <artifactId>metrics</artifactId> <version>${project.version}</version> </dependency> diff --git a/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerFactory.java b/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerFactory.java index d8ea45e716d..e8a3038639a 100644 --- a/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerFactory.java +++ b/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerFactory.java @@ -4,10 +4,13 @@ package com.yahoo.filedistribution.fileacquirer; /** * Hides the real file acquirer type from 3rd party developers. * Not intended to be used by 3rd parties. + * * @author Tony Vaagenes */ public class FileAcquirerFactory { + public static FileAcquirer create(String configId) { return new FileAcquirerImpl(configId); } + } diff --git a/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerImpl.java b/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerImpl.java index fca4b206fc9..ab0f7521e7e 100644 --- a/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerImpl.java +++ b/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerImpl.java @@ -23,12 +23,15 @@ import java.io.File; * @author Tony Vaagenes */ class FileAcquirerImpl implements FileAcquirer { + static final class FileDistributionErrorCode { + public static final int baseErrorCode = 0x10000; public static final int baseFileProviderErrorCode = baseErrorCode + 0x1000; public static final int fileReferenceDoesNotExists = baseFileProviderErrorCode; public static final int fileReferenceRemoved = fileReferenceDoesNotExists + 1; + } private static final Logger log = Logger.getLogger(FileAcquirerImpl.class.getName()); @@ -131,13 +134,10 @@ class FileAcquirerImpl implements FileAcquirer { * given file reference. File references are produced by the * config system. * - * @throws TimeoutException if the file or directory could not be - * retrieved in time. - * @throws FileReferenceDoesNotExistException if the file is no - * longer available (due to reloading of config). + * @throws TimeoutException if the file or directory could not be retrieved in time. + * @throws FileReferenceDoesNotExistException if the file is no longer available (due to reloading of config). */ - public File waitFor(FileReference fileReference, long timeout, TimeUnit timeUnit) - throws InterruptedException { + public File waitFor(FileReference fileReference, long timeout, TimeUnit timeUnit) throws InterruptedException { Timer timer = new Timer(timeout, timeUnit); do { Target target = connection.getTarget(timer); diff --git a/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/MockFileAcquirer.java b/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/MockFileAcquirer.java index 25732d2dcc8..1a8a05d0a53 100644 --- a/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/MockFileAcquirer.java +++ b/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/MockFileAcquirer.java @@ -14,8 +14,9 @@ import java.util.concurrent.TimeUnit; * @author Tony Vaagenes */ public abstract class MockFileAcquirer implements FileAcquirer { + /** Creates a FileAcquirer that always returns the given file. **/ - public static FileAcquirer returnFile(final File file) { + public static FileAcquirer returnFile(File file) { return new MockFileAcquirer() { @Override public File waitFor(FileReference fileReference, @@ -26,7 +27,7 @@ public abstract class MockFileAcquirer implements FileAcquirer { } /** Creates a FileAcquirer that maps from fileReference.value to a file. **/ - public static FileAcquirer returnFiles(final Map<String, File> files) { + public static FileAcquirer returnFiles(Map<String, File> files) { return new MockFileAcquirer() { @Override public File waitFor(FileReference fileReference, @@ -60,4 +61,5 @@ public abstract class MockFileAcquirer implements FileAcquirer { @Override public void shutdown() {} + } diff --git a/fnet/build/buildspec.xml b/fnet/build/buildspec.xml deleted file mode 100644 index 22ce5b93d4a..00000000000 --- a/fnet/build/buildspec.xml +++ /dev/null @@ -1,39 +0,0 @@ -<!-- Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> -<BuildSpecification> - <Owner> - <OwnerName></OwnerName> - <OwnerEmail></OwnerEmail> - </Owner> - - <Dependencies os="all" arch="all"> - <dep package="common/fastos" version="0.0" /> - </Dependencies> - - <PreBuild os="all" arch="all"> - <configure path="src"> - <parameter value="--fastos-dir ${fbuild_install_dir}/fastos" /> - <parameter value="--install-dir ${fbuild_install_dir}/fnet" /> - <if feature="nodirectwrite"> - <parameter value="--disable-direct-write" /> - </if> - </configure> - </PreBuild> - - <Build os="all" arch="all"> - <make path="src" target="bootstrap" /> - </Build> - - <PostBuild os="all" arch="all"> - </PostBuild> - - <Test os="all" arch="all"> - </Test> - - <Install os="all" arch="all"> - <make path="src" target="install" /> - </Install> - - <Dist os="all" arch="all"> - </Dist> - -</BuildSpecification> diff --git a/fnet/src/examples/frt/rpc/rpc_callback_client.cpp b/fnet/src/examples/frt/rpc/rpc_callback_client.cpp index 801de59b515..7c6434e870a 100644 --- a/fnet/src/examples/frt/rpc/rpc_callback_client.cpp +++ b/fnet/src/examples/frt/rpc/rpc_callback_client.cpp @@ -26,7 +26,7 @@ RPC::Init(FRT_Supervisor *s) { FRT_ReflectionBuilder rb(s); //------------------------------------------------------------------- - rb.DefineMethod("prod", "", "", true, + rb.DefineMethod("prod", "", "", FRT_METHOD(RPC::Prod), this); //------------------------------------------------------------------- } @@ -45,6 +45,7 @@ MyApp::Main() printf("usage : rpc_server <connectspec>\n"); return 1; } + bool ok = true; RPC rpc; FRT_Supervisor orb; rpc.Init(&orb); @@ -63,6 +64,7 @@ MyApp::Main() printf("[error(%d): %s]\n", req->GetErrorCode(), req->GetErrorMessage()); + ok = false; } printf("invokeCnt: %d\n", rpc.invokeCnt); @@ -76,6 +78,7 @@ MyApp::Main() printf("[error(%d): %s]\n", req->GetErrorCode(), req->GetErrorMessage()); + ok = false; } printf("invokeCnt: %d\n", rpc.invokeCnt); @@ -89,14 +92,18 @@ MyApp::Main() printf("[error(%d): %s]\n", req->GetErrorCode(), req->GetErrorMessage()); + ok = false; } printf("invokeCnt: %d\n", rpc.invokeCnt); + if (rpc.invokeCnt != 3) { + ok = false; + } req->SubRef(); target->SubRef(); orb.ShutDown(true); - return 0; + return ok ? 0 : 1; } diff --git a/fnet/src/examples/frt/rpc/rpc_callback_server.cpp b/fnet/src/examples/frt/rpc/rpc_callback_server.cpp index 33207282bcc..ac7b34ebda0 100644 --- a/fnet/src/examples/frt/rpc/rpc_callback_server.cpp +++ b/fnet/src/examples/frt/rpc/rpc_callback_server.cpp @@ -2,6 +2,7 @@ #include <vespa/fnet/frt/frt.h> #include <vespa/fastos/app.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP("rpc_callback_server"); @@ -12,9 +13,7 @@ struct RPC : public FRT_Invokable void Init(FRT_Supervisor *s); }; -void -RPC::CallBack(FRT_RPCRequest *req) -{ +void do_callback(FRT_RPCRequest *req) { FNET_Connection *conn = req->GetConnection(); FRT_RPCRequest *cb = new FRT_RPCRequest(); cb->SetMethodName(req->GetParams()->GetValue(0)._string._str); @@ -25,6 +24,14 @@ RPC::CallBack(FRT_RPCRequest *req) cb->GetErrorMessage()); } cb->SubRef(); + req->Return(); +} + +void +RPC::CallBack(FRT_RPCRequest *req) +{ + req->Detach(); + std::thread(do_callback, req).detach(); } void @@ -32,7 +39,7 @@ RPC::Init(FRT_Supervisor *s) { FRT_ReflectionBuilder rb(s); //------------------------------------------------------------------- - rb.DefineMethod("callBack", "s", "", false, + rb.DefineMethod("callBack", "s", "", FRT_METHOD(RPC::CallBack), this); //------------------------------------------------------------------- } diff --git a/fnet/src/examples/frt/rpc/rpc_server.cpp b/fnet/src/examples/frt/rpc/rpc_server.cpp index 8947663216e..03d618133c9 100644 --- a/fnet/src/examples/frt/rpc/rpc_server.cpp +++ b/fnet/src/examples/frt/rpc/rpc_server.cpp @@ -28,21 +28,21 @@ RPCServer::InitRPC(FRT_Supervisor *s) { FRT_ReflectionBuilder rb(s); //------------------------------------------------------------------- - rb.DefineMethod("concat", "ss", "s", true, + rb.DefineMethod("concat", "ss", "s", FRT_METHOD(RPCServer::RPC_concat), this); rb.MethodDesc("Concatenate two strings"); rb.ParamDesc("string1", "a string"); rb.ParamDesc("string2", "another string"); rb.ReturnDesc("ret", "the concatenation of string1 and string2"); //------------------------------------------------------------------- - rb.DefineMethod("addFloat", "ff", "f", true, + rb.DefineMethod("addFloat", "ff", "f", FRT_METHOD(RPCServer::RPC_addFloat), this); rb.MethodDesc("Add two floats"); rb.ParamDesc("float1", "a float"); rb.ParamDesc("float2", "another float"); rb.ReturnDesc("ret", "float1 + float2"); //------------------------------------------------------------------- - rb.DefineMethod("addDouble", "dd", "d", true, + rb.DefineMethod("addDouble", "dd", "d", FRT_METHOD(RPCServer::RPC_addDouble), this); rb.MethodDesc("Add two doubles"); rb.ParamDesc("double1", "a double"); diff --git a/fnet/src/examples/proxy/proxy.cpp b/fnet/src/examples/proxy/proxy.cpp index 653b445581f..a01a16ead9c 100644 --- a/fnet/src/examples/proxy/proxy.cpp +++ b/fnet/src/examples/proxy/proxy.cpp @@ -227,7 +227,6 @@ Proxy::Main() if (listener != nullptr) listener->SubRef(); - _transport.SetLogStats(true); FNET_SignalShutDown ssd(_transport); _transport.Main(); return 0; diff --git a/fnet/src/tests/frt/method_pt/method_pt.cpp b/fnet/src/tests/frt/method_pt/method_pt.cpp index db5905d6871..5417fddceeb 100644 --- a/fnet/src/tests/frt/method_pt/method_pt.cpp +++ b/fnet/src/tests/frt/method_pt/method_pt.cpp @@ -207,35 +207,35 @@ void initTest() { //------------------------------------------------------------------- - rb.DefineMethod("simpleMethod", "", "", true, + rb.DefineMethod("simpleMethod", "", "", FRT_METHOD(SimpleHandler::RPC_Method), _simpleHandler); //------------------------------------------------------------------- - rb.DefineMethod("mediumMethod1", "", "", true, + rb.DefineMethod("mediumMethod1", "", "", FRT_METHOD(MediumHandler1::RPC_Method), _mediumHandler1); - rb.DefineMethod("mediumMethod2", "", "", true, + rb.DefineMethod("mediumMethod2", "", "", FRT_METHOD(MediumHandler2::RPC_Method), _mediumHandler2); - rb.DefineMethod("mediumMethod3", "", "", true, + rb.DefineMethod("mediumMethod3", "", "", FRT_METHOD(MediumHandler3::RPC_Method), _mediumHandler3); //------------------------------------------------------------------- - rb.DefineMethod("complexMethod1", "", "", true, + rb.DefineMethod("complexMethod1", "", "", FRT_METHOD(ComplexHandler1::RPC_Method), _complexHandler1); - rb.DefineMethod("complexMethod2", "", "", true, + rb.DefineMethod("complexMethod2", "", "", FRT_METHOD(ComplexHandler2::RPC_Method), _complexHandler2); - rb.DefineMethod("complexMethod3", "", "", true, + rb.DefineMethod("complexMethod3", "", "", FRT_METHOD(ComplexHandler3::RPC_Method), _complexHandler3); diff --git a/fnet/src/tests/frt/parallel_rpc/parallel_rpc_test.cpp b/fnet/src/tests/frt/parallel_rpc/parallel_rpc_test.cpp index 478e4b14f02..59ca9d4ccc0 100644 --- a/fnet/src/tests/frt/parallel_rpc/parallel_rpc_test.cpp +++ b/fnet/src/tests/frt/parallel_rpc/parallel_rpc_test.cpp @@ -3,16 +3,19 @@ #include <vespa/vespalib/util/stringfmt.h> #include <vespa/fnet/frt/frt.h> #include <vespa/vespalib/util/benchmark_timer.h> +#include <vespa/vespalib/net/crypto_engine.h> +#include <vespa/vespalib/net/tls/tls_crypto_engine.h> +#include <vespa/vespalib/test/make_tls_options_for_testing.h> #include <thread> -using vespalib::BenchmarkTimer; +using namespace vespalib; struct Rpc : FRT_Invokable { FastOS_ThreadPool thread_pool; FNET_Transport transport; FRT_Supervisor orb; - Rpc(size_t num_threads) - : thread_pool(128 * 1024), transport(num_threads), orb(&transport, &thread_pool) {} + Rpc(CryptoEngine::SP crypto, size_t num_threads) + : thread_pool(128 * 1024), transport(crypto, num_threads), orb(&transport, &thread_pool) {} void start() { ASSERT_TRUE(transport.Start(&thread_pool)); } @@ -31,13 +34,13 @@ struct Rpc : FRT_Invokable { struct Server : Rpc { uint32_t port; - Server(size_t num_threads) : Rpc(num_threads), port(listen()) { + Server(CryptoEngine::SP crypto, size_t num_threads) : Rpc(crypto, num_threads), port(listen()) { init_rpc(); start(); } void init_rpc() { FRT_ReflectionBuilder rb(&orb); - rb.DefineMethod("inc", "l", "l", true, FRT_METHOD(Server::rpc_inc), this); + rb.DefineMethod("inc", "l", "l", FRT_METHOD(Server::rpc_inc), this); rb.MethodDesc("increment a 64-bit integer"); rb.ParamDesc("in", "an integer (64 bit)"); rb.ReturnDesc("out", "in + 1 (64 bit)"); @@ -51,7 +54,7 @@ struct Server : Rpc { struct Client : Rpc { uint32_t port; - Client(size_t num_threads, const Server &server) : Rpc(num_threads), port(server.port) { + Client(CryptoEngine::SP crypto, size_t num_threads, const Server &server) : Rpc(crypto, num_threads), port(server.port) { start(); } FRT_Target *connect() { return Rpc::connect(port); } @@ -93,8 +96,8 @@ void perform_test(size_t thread_id, Client &client, Result &result) { seq = ret; }; size_t loop_cnt = 128; - BenchmarkTimer::benchmark(invoke, invoke, 1.0); - BenchmarkTimer timer(3.0); + BenchmarkTimer::benchmark(invoke, invoke, 0.5); + BenchmarkTimer timer(1.5); while (timer.has_budget()) { timer.before(); for (size_t i = 0; i < loop_cnt; ++i) { @@ -103,7 +106,7 @@ void perform_test(size_t thread_id, Client &client, Result &result) { timer.after(); } double t = timer.min_time(); - BenchmarkTimer::benchmark(invoke, invoke, 1.0); + BenchmarkTimer::benchmark(invoke, invoke, 0.5); EXPECT_GREATER_EQUAL(seq, loop_cnt); result.req_per_sec[thread_id] = double(loop_cnt) / t; req->SubRef(); @@ -114,16 +117,26 @@ void perform_test(size_t thread_id, Client &client, Result &result) { } } -TEST_MT_FFF("parallel rpc with 1/1 transport threads and 128 user threads", - 128, Server(1), Client(1, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } +CryptoEngine::SP null_crypto = std::make_shared<NullCryptoEngine>(); +CryptoEngine::SP xor_crypto = std::make_shared<XorCryptoEngine>(); +// CryptoEngine::SP tls_crypto = std::make_shared<vespalib::TlsCryptoEngine>(vespalib::test::make_tls_options_for_testing()); -TEST_MT_FFF("parallel rpc with 1/8 transport threads and 128 user threads", - 128, Server(8), Client(1, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } +TEST_MT_FFF("parallel rpc with 1/1 transport threads and 128 user threads (no encryption)", + 128, Server(null_crypto, 1), Client(null_crypto, 1, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } -TEST_MT_FFF("parallel rpc with 8/1 transport threads and 128 user threads", - 128, Server(1), Client(8, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } +TEST_MT_FFF("parallel rpc with 1/1 transport threads and 128 user threads (xor encryption)", + 128, Server(xor_crypto, 1), Client(xor_crypto, 1, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } -TEST_MT_FFF("parallel rpc with 8/8 transport threads and 128 user threads", - 128, Server(8), Client(8, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } +// TEST_MT_FFF("parallel rpc with 1/1 transport threads and 128 user threads (tls encryption)", +// 128, Server(tls_crypto, 1), Client(tls_crypto, 1, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } + +TEST_MT_FFF("parallel rpc with 8/8 transport threads and 128 user threads (no encryption)", + 128, Server(null_crypto, 8), Client(null_crypto, 8, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } + +TEST_MT_FFF("parallel rpc with 8/8 transport threads and 128 user threads (xor encryption)", + 128, Server(xor_crypto, 8), Client(xor_crypto, 8, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } + +// TEST_MT_FFF("parallel rpc with 8/8 transport threads and 128 user threads (tls encryption)", +// 128, Server(tls_crypto, 8), Client(tls_crypto, 8, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/fnet/src/tests/frt/rpc/CMakeLists.txt b/fnet/src/tests/frt/rpc/CMakeLists.txt index b426ed42397..4c1883ee7bd 100644 --- a/fnet/src/tests/frt/rpc/CMakeLists.txt +++ b/fnet/src/tests/frt/rpc/CMakeLists.txt @@ -6,6 +6,8 @@ vespa_add_executable(fnet_invoke_test_app TEST fnet ) vespa_add_test(NAME fnet_invoke_test_app COMMAND fnet_invoke_test_app) +vespa_add_test(NAME fnet_invoke_test_app_xor COMMAND fnet_invoke_test_app ENVIRONMENT "CRYPTOENGINE=xor") +# vespa_add_test(NAME fnet_invoke_test_app_tls COMMAND fnet_invoke_test_app ENVIRONMENT "CRYPTOENGINE=tls") vespa_add_executable(fnet_detach_return_invoke_test_app TEST SOURCES detach_return_invoke.cpp @@ -20,6 +22,8 @@ vespa_add_executable(fnet_session_test_app TEST fnet ) vespa_add_test(NAME fnet_session_test_app COMMAND fnet_session_test_app) +vespa_add_test(NAME fnet_session_test_app_xor COMMAND fnet_session_test_app ENVIRONMENT "CRYPTOENGINE=xor") +# vespa_add_test(NAME fnet_session_test_app_tls COMMAND fnet_session_test_app ENVIRONMENT "CRYPTOENGINE=tls") vespa_add_executable(fnet_sharedblob_test_app TEST SOURCES sharedblob.cpp diff --git a/fnet/src/tests/frt/rpc/detach_return_invoke.cpp b/fnet/src/tests/frt/rpc/detach_return_invoke.cpp index 54a891261c2..ab21c62bb68 100644 --- a/fnet/src/tests/frt/rpc/detach_return_invoke.cpp +++ b/fnet/src/tests/frt/rpc/detach_return_invoke.cpp @@ -20,7 +20,7 @@ struct Server : public FRT_Invokable Server(FRT_Supervisor &s, Receptor &r) : orb(s), receptor(r) { FRT_ReflectionBuilder rb(&s); - rb.DefineMethod("hook", "", "", true, + rb.DefineMethod("hook", "", "", FRT_METHOD(Server::rpc_hook), this); } diff --git a/fnet/src/tests/frt/rpc/invoke.cpp b/fnet/src/tests/frt/rpc/invoke.cpp index f44a58dd8b3..e3bd662214f 100644 --- a/fnet/src/tests/frt/rpc/invoke.cpp +++ b/fnet/src/tests/frt/rpc/invoke.cpp @@ -6,6 +6,11 @@ //------------------------------------------------------------- +#include "my_crypto_engine.hpp" +vespalib::CryptoEngine::SP crypto; + +//------------------------------------------------------------- + std::mutex _delayedReturnCntLock; uint32_t _delayedReturnCnt = 0; @@ -119,7 +124,7 @@ public: assert(_echo_stash != nullptr && _echo_args != nullptr); FRT_ReflectionBuilder rb(supervisor); - rb.DefineMethod("echo", "*", "*", true, + rb.DefineMethod("echo", "*", "*", FRT_METHOD(EchoTest::RPC_Echo), this); FRT_Values *args = _echo_args; @@ -220,17 +225,15 @@ public: { FRT_ReflectionBuilder rb(supervisor); - rb.DefineMethod("inc", "i", "i", true, + rb.DefineMethod("inc", "i", "i", FRT_METHOD(TestRPC::RPC_Inc), this); - rb.DefineMethod("setValue", "i", "", true, + rb.DefineMethod("setValue", "i", "", FRT_METHOD(TestRPC::RPC_SetValue), this); - rb.DefineMethod("incValue", "", "", true, + rb.DefineMethod("incValue", "", "", FRT_METHOD(TestRPC::RPC_IncValue), this); - rb.DefineMethod("getValue", "", "i", true, + rb.DefineMethod("getValue", "", "i", FRT_METHOD(TestRPC::RPC_GetValue), this); - rb.DefineMethod("testFast", "iiibb", "i", true, - FRT_METHOD(TestRPC::RPC_Test), this); - rb.DefineMethod("testSlow", "iiibb", "i", false, + rb.DefineMethod("testFast", "iiibb", "i", FRT_METHOD(TestRPC::RPC_Test), this); } @@ -359,7 +362,6 @@ const char phase_names[PHASE_ZZZ][32] = enum { TIMING_NULL = 0, TIMING_INSTANT, - TIMING_NON_INSTANT, TIMING_ZZZ }; @@ -367,7 +369,6 @@ const char timing_names[TIMING_ZZZ][32] = { "nullptr", "INSTANT", - "NON-INSTANT" }; enum { @@ -400,8 +401,8 @@ struct State { FRT_RPCRequest *_req; State() - : _client(), - _server(), + : _client(crypto), + _server(crypto), _rpc(&_server, _client.GetScheduler()), _echo(), _peerSpec(), @@ -446,17 +447,10 @@ struct State { void PrepareTestMethod() { NewReq(); - bool instant = (_timing == TIMING_INSTANT); - if (_timing != TIMING_INSTANT && - _timing != TIMING_NON_INSTANT) - { + if (_timing != TIMING_INSTANT) { ASSERT_TRUE(false); // consult your dealer... } - if (instant) { - _req->SetMethodName("testFast"); - } else { - _req->SetMethodName("testSlow"); - } + _req->SetMethodName("testFast"); } void SetTestParams(uint32_t value, uint32_t delay, @@ -923,10 +917,13 @@ TEST_F("invoke test", State()) { EXPECT_TRUE(_phase_simple_cnt == 1); EXPECT_TRUE(_phase_void_cnt == 1); EXPECT_TRUE(_phase_speed_cnt == 1); - EXPECT_TRUE(_phase_advanced_cnt == 4); - EXPECT_TRUE(_phase_error_cnt == 4); - EXPECT_TRUE(_phase_abort_cnt == 4); + EXPECT_TRUE(_phase_advanced_cnt == 2); + EXPECT_TRUE(_phase_error_cnt == 2); + EXPECT_TRUE(_phase_abort_cnt == 2); EXPECT_TRUE(_phase_echo_cnt == 1); } -TEST_MAIN() { TEST_RUN_ALL(); } +TEST_MAIN() { + crypto = my_crypto_engine(); + TEST_RUN_ALL(); +} diff --git a/fnet/src/tests/frt/rpc/my_crypto_engine.hpp b/fnet/src/tests/frt/rpc/my_crypto_engine.hpp new file mode 100644 index 00000000000..6cd8d47e917 --- /dev/null +++ b/fnet/src/tests/frt/rpc/my_crypto_engine.hpp @@ -0,0 +1,22 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/net/tls/tls_crypto_engine.h> +#include <vespa/vespalib/test/make_tls_options_for_testing.h> + +vespalib::CryptoEngine::SP my_crypto_engine() { + const char *env_str = getenv("CRYPTOENGINE"); + if (!env_str) { + fprintf(stderr, "crypto engine: null\n"); + return std::make_shared<vespalib::NullCryptoEngine>(); + } + std::string engine(env_str); + if (engine == "xor") { + fprintf(stderr, "crypto engine: xor\n"); + return std::make_shared<vespalib::XorCryptoEngine>(); + } else if (engine == "tls") { + fprintf(stderr, "crypto engine: tls\n"); + return std::make_shared<vespalib::TlsCryptoEngine>(vespalib::test::make_tls_options_for_testing()); + } + TEST_FATAL(("invalid crypto engine: " + engine).c_str()); + abort(); +} diff --git a/fnet/src/tests/frt/rpc/session.cpp b/fnet/src/tests/frt/rpc/session.cpp index 2e920ac9f98..93f14647e21 100644 --- a/fnet/src/tests/frt/rpc/session.cpp +++ b/fnet/src/tests/frt/rpc/session.cpp @@ -4,6 +4,12 @@ #include <vespa/fnet/frt/frt.h> #include <mutex> +//------------------------------------------------------------- + +#include "my_crypto_engine.hpp" +vespalib::CryptoEngine::SP crypto; + +//------------------------------------------------------------- class Session { @@ -71,9 +77,9 @@ struct RPC : public FRT_Invokable void Init(FRT_Supervisor *s) { FRT_ReflectionBuilder rb(s); - rb.DefineMethod("getValue", "", "i", true, + rb.DefineMethod("getValue", "", "i", FRT_METHOD(RPC::GetValue), this); - rb.DefineMethod("setValue", "i", "", true, + rb.DefineMethod("setValue", "i", "", FRT_METHOD(RPC::SetValue), this); s->SetSessionInitHook(FRT_METHOD(RPC::InitSession), this); s->SetSessionFiniHook(FRT_METHOD(RPC::FiniSession), this); @@ -82,7 +88,7 @@ struct RPC : public FRT_Invokable TEST("session") { RPC rpc; - FRT_Supervisor orb; + FRT_Supervisor orb(crypto); char spec[64]; rpc.Init(&orb); ASSERT_TRUE(orb.Listen("tcp/0")); @@ -121,4 +127,7 @@ TEST("session") { EXPECT_TRUE(!rpc.bogusFini); }; -TEST_MAIN() { TEST_RUN_ALL(); } +TEST_MAIN() { + crypto = my_crypto_engine(); + TEST_RUN_ALL(); +} diff --git a/fnet/src/tests/frt/rpc/sharedblob.cpp b/fnet/src/tests/frt/rpc/sharedblob.cpp index 10eaad9c013..a48ecbb1da7 100644 --- a/fnet/src/tests/frt/rpc/sharedblob.cpp +++ b/fnet/src/tests/frt/rpc/sharedblob.cpp @@ -176,7 +176,7 @@ TEST("testImplicitShared") { ServerSampler serverSampler(dataSet, req); { FRT_ReflectionBuilder rb(&orb); - rb.DefineMethod("test", "*", "*", true, + rb.DefineMethod("test", "*", "*", FRT_METHOD(ServerSampler::RPC_test), &serverSampler); } orb.Listen(0); diff --git a/fnet/src/tests/info/info.cpp b/fnet/src/tests/info/info.cpp index cd0364cad6f..f76e66c2af6 100644 --- a/fnet/src/tests/info/info.cpp +++ b/fnet/src/tests/info/info.cpp @@ -24,7 +24,7 @@ struct RPC : public FRT_Invokable { FRT_ReflectionBuilder rb(s); //------------------------------------------------------------------- - rb.DefineMethod("getInfo", "", "sssii", true, + rb.DefineMethod("getInfo", "", "sssii", FRT_METHOD(RPC::GetInfo), this); // FastOS version // FNET version @@ -70,10 +70,10 @@ TEST("info") { TEST("size of important objects") { - EXPECT_EQUAL(184u, sizeof(FNET_IOComponent)); + EXPECT_EQUAL(168u, sizeof(FNET_IOComponent)); EXPECT_EQUAL(32u, sizeof(FNET_Channel)); EXPECT_EQUAL(40u, sizeof(FNET_PacketQueue_NoLock)); - EXPECT_EQUAL(488u, sizeof(FNET_Connection)); + EXPECT_EQUAL(472u, sizeof(FNET_Connection)); EXPECT_EQUAL(48u, sizeof(std::condition_variable)); EXPECT_EQUAL(56u, sizeof(FNET_DataBuffer)); EXPECT_EQUAL(24u, sizeof(FastOS_Time)); diff --git a/fnet/src/vespa/fnet/CMakeLists.txt b/fnet/src/vespa/fnet/CMakeLists.txt index 20badc5c489..4b9d818e5ed 100644 --- a/fnet/src/vespa/fnet/CMakeLists.txt +++ b/fnet/src/vespa/fnet/CMakeLists.txt @@ -17,7 +17,6 @@ vespa_add_library(fnet scheduler.cpp signalshutdown.cpp simplepacketstreamer.cpp - stats.cpp task.cpp transport.cpp transport_thread.cpp diff --git a/fnet/src/vespa/fnet/config.cpp b/fnet/src/vespa/fnet/config.cpp index 01bc76791de..a546d38f78b 100644 --- a/fnet/src/vespa/fnet/config.cpp +++ b/fnet/src/vespa/fnet/config.cpp @@ -3,12 +3,9 @@ #include "config.h" FNET_Config::FNET_Config() - : _minEventTimeOut(0), - _pingInterval(0), - _iocTimeOut(0), + : _iocTimeOut(0), _maxInputBufferSize(0x10000), _maxOutputBufferSize(0x10000), - _tcpNoDelay(true), - _logStats(false), - _directWrite(false) -{ } + _tcpNoDelay(true) +{ +} diff --git a/fnet/src/vespa/fnet/config.h b/fnet/src/vespa/fnet/config.h index e92540ce0bc..3f34c1511b6 100644 --- a/fnet/src/vespa/fnet/config.h +++ b/fnet/src/vespa/fnet/config.h @@ -11,15 +11,10 @@ class FNET_Config { public: - uint32_t _minEventTimeOut; - uint32_t _pingInterval; uint32_t _iocTimeOut; uint32_t _maxInputBufferSize; uint32_t _maxOutputBufferSize; bool _tcpNoDelay; - bool _logStats; - bool _directWrite; FNET_Config(); }; - diff --git a/fnet/src/vespa/fnet/connection.cpp b/fnet/src/vespa/fnet/connection.cpp index 07086ca54a2..e028afe5deb 100644 --- a/fnet/src/vespa/fnet/connection.cpp +++ b/fnet/src/vespa/fnet/connection.cpp @@ -110,13 +110,6 @@ FNET_Connection::SetState(State state) } if (oldstate < FNET_CLOSING && state >= FNET_CLOSING) { - if (_flags._writeLock) { - _flags._discarding = true; - while (_flags._writeLock) - _ioc_cond.wait(guard); - _flags._discarding = false; - } - while (!_queue.IsEmpty_NoLock() || !_myQueue.IsEmpty_NoLock()) { _flags._discarding = true; _queue.FlushPackets_NoLock(&_myQueue); @@ -232,15 +225,15 @@ FNET_Connection::handshake() case vespalib::CryptoSocket::HandshakeResult::DONE: { EnableReadEvent(true); EnableWriteEvent(writePendingAfterConnect()); + _flags._framed = (_socket->min_read_buffer_size() > 1); size_t chunk_size = std::max(size_t(FNET_READ_SIZE), _socket->min_read_buffer_size()); - uint32_t ignore_stats = 0; ssize_t res = 0; do { // drain input pipeline _input.EnsureFree(chunk_size); res = _socket->drain(_input.GetFree(), _input.GetFreeLen()); if (res > 0) { _input.FreeToData((uint32_t)res); - broken = !handle_packets(ignore_stats); + broken = !handle_packets(); _input.resetIfEmpty(); } } while ((res > 0) && !broken); } @@ -258,7 +251,7 @@ FNET_Connection::handshake() } bool -FNET_Connection::handle_packets(uint32_t &read_packets) +FNET_Connection::handle_packets() { bool broken = false; for (bool done = false; !done;) { // handle each complete packet in the buffer. @@ -268,7 +261,6 @@ FNET_Connection::handle_packets(uint32_t &read_packets) &broken); } if (_flags._gotheader && (_input.GetDataLen() >= _packetLength)) { - read_packets++; HandlePacket(_packetLength, _packetCode, _packetCHID); _flags._gotheader = false; // reset header flag. } else { @@ -282,26 +274,26 @@ bool FNET_Connection::Read() { size_t chunk_size = std::max(size_t(FNET_READ_SIZE), _socket->min_read_buffer_size()); - uint32_t readData = 0; // total data read - uint32_t readPackets = 0; // total packets read int readCnt = 0; // read count bool broken = false; // is this conn broken ? + int my_errno = 0; // sample and preserve errno ssize_t res; // single read result _input.EnsureFree(chunk_size); res = _socket->read(_input.GetFree(), _input.GetFreeLen()); + my_errno = errno; readCnt++; while (res > 0) { _input.FreeToData((uint32_t)res); - readData += (uint32_t)res; - broken = !handle_packets(readPackets); + broken = !handle_packets(); _input.resetIfEmpty(); - if (broken || (_input.GetFreeLen() > 0) || (readCnt >= FNET_READ_REDO)) { + if (broken || ((_input.GetFreeLen() > 0) && !_flags._framed) || (readCnt >= FNET_READ_REDO)) { goto done_read; } _input.EnsureFree(chunk_size); res = _socket->read(_input.GetFree(), _input.GetFreeLen()); + my_errno = errno; readCnt++; } @@ -310,28 +302,23 @@ done_read: while ((res > 0) && !broken) { // drain input pipeline _input.EnsureFree(chunk_size); res = _socket->drain(_input.GetFree(), _input.GetFreeLen()); - readCnt++; + my_errno = errno; if (res > 0) { _input.FreeToData((uint32_t)res); - readData += (uint32_t)res; - broken = !handle_packets(readPackets); + broken = !handle_packets(); _input.resetIfEmpty(); } else if (res == 0) { // fully drained -> EWOULDBLOCK - errno = EWOULDBLOCK; + my_errno = EWOULDBLOCK; res = -1; } } - if (readData > 0) { - UpdateTimeOut(); - CountDataRead(readData); - CountPacketRead(readPackets); - uint32_t maxSize = GetConfig()->_maxInputBufferSize; - if (maxSize > 0 && _input.GetBufSize() > maxSize) - { - if (!_flags._gotheader || _packetLength < maxSize) { - _input.Shrink(maxSize); - } + UpdateTimeOut(); + uint32_t maxSize = GetConfig()->_maxInputBufferSize; + if (maxSize > 0 && _input.GetBufSize() > maxSize) + { + if (!_flags._gotheader || _packetLength < maxSize) { + _input.Shrink(maxSize); } } @@ -339,9 +326,9 @@ done_read: if (res == 0) { broken = true; // handle EOF } else { // res < 0 - broken = ((errno != EWOULDBLOCK) && (errno != EAGAIN)); - if (broken && (errno != ECONNRESET)) { - LOG(debug, "Connection(%s): read error: %d", GetSpec(), errno); + broken = ((my_errno != EWOULDBLOCK) && (my_errno != EAGAIN)); + if (broken && (my_errno != ECONNRESET)) { + LOG(debug, "Connection(%s): read error: %d", GetSpec(), my_errno); } } } @@ -351,13 +338,13 @@ done_read: bool -FNET_Connection::Write(bool direct) +FNET_Connection::Write() { + size_t chunk_size = std::max(size_t(FNET_WRITE_SIZE), _socket->min_read_buffer_size()); uint32_t my_write_work = 0; - uint32_t writtenData = 0; // total data written - uint32_t writtenPackets = 0; // total packets written int writeCnt = 0; // write count bool broken = false; // is this conn broken ? + int my_errno = 0; // sample and preserve errno ssize_t res; // single write result FNET_Packet *packet; @@ -367,14 +354,13 @@ FNET_Connection::Write(bool direct) // fill output buffer - while (_output.GetDataLen() < FNET_WRITE_SIZE) { + while (_output.GetDataLen() < chunk_size) { if (_myQueue.IsEmpty_NoLock()) break; packet = _myQueue.DequeuePacket_NoLock(&context); if (packet->IsRegularPacket()) { // ignore non-regular packets _streamer->Encode(packet, context._value.INT, &_output); - writtenPackets++; } packet->Free(); } @@ -387,10 +373,10 @@ FNET_Connection::Write(bool direct) // write data res = _socket->write(_output.GetData(), _output.GetDataLen()); + my_errno = errno; writeCnt++; if (res > 0) { _output.DataToDead((uint32_t)res); - writtenData += (uint32_t)res; _output.resetIfEmpty(); } } while (res > 0 && @@ -404,26 +390,26 @@ FNET_Connection::Write(bool direct) if (res >= 0) { // flush output pipeline res = _socket->flush(); + my_errno = errno; while (res > 0) { res = _socket->flush(); + my_errno = errno; } } - if (writtenData > 0) { - uint32_t maxSize = GetConfig()->_maxOutputBufferSize; - if (maxSize > 0 && _output.GetBufSize() > maxSize) { - _output.Shrink(maxSize); - } + uint32_t maxSize = GetConfig()->_maxOutputBufferSize; + if (maxSize > 0 && _output.GetBufSize() > maxSize) { + _output.Shrink(maxSize); } if (res < 0) { - if ((errno == EWOULDBLOCK) || (errno == EAGAIN)) { + if ((my_errno == EWOULDBLOCK) || (my_errno == EAGAIN)) { ++my_write_work; // incomplete write/flush } else { broken = true; } - if (broken && (errno != ECONNRESET)) { - LOG(debug, "Connection(%s): write error: %d", GetSpec(), errno); + if (broken && (my_errno != ECONNRESET)) { + LOG(debug, "Connection(%s): write error: %d", GetSpec(), my_errno); } } @@ -431,35 +417,11 @@ FNET_Connection::Write(bool direct) _writeWork = _queue.GetPacketCnt_NoLock() + _myQueue.GetPacketCnt_NoLock() + my_write_work; - _flags._writeLock = false; - if (_flags._discarding) { - _ioc_cond.notify_all(); - } bool writePending = (_writeWork > 0); - if (direct) { // direct write (from post packet) - if (writtenData > 0) { - CountDirectDataWrite(writtenData); - CountDirectPacketWrite(writtenPackets); - } - if (writePending) { - AddRef_NoLock(); - guard.unlock(); - if (broken) { - Owner()->Close(this, /* needRef = */ false); - } else { - Owner()->EnableWrite(this, /* needRef = */ false); - } - } - } else { // normal write (from event loop) - guard.unlock(); - if (writtenData > 0) { - CountDataWrite(writtenData); - CountPacketWrite(writtenPackets); - } - if (!writePending) - EnableWriteEvent(false); - } + guard.unlock(); + if (!writePending) + EnableWriteEvent(false); return !broken; } @@ -544,7 +506,6 @@ FNET_Connection::~FNET_Connection() delete _adminChannel; } assert(_cleanup == nullptr); - assert(!_flags._writeLock); } @@ -698,32 +659,15 @@ FNET_Connection::PostPacket(FNET_Packet *packet, uint32_t chid) writeWork = _writeWork; _writeWork++; _queue.QueuePacket_NoLock(packet, FNET_Context(chid)); - if (writeWork == 0 && !_flags._writeLock && - _state == FNET_CONNECTED) - { - if (GetConfig()->_directWrite) { - _flags._writeLock = true; - _queue.FlushPackets_NoLock(&_myQueue); - guard.unlock(); - Write(true); - } else { - AddRef_NoLock(); - guard.unlock(); - Owner()->EnableWrite(this, /* needRef = */ false); - } + if ((writeWork == 0) && (_state == FNET_CONNECTED)) { + AddRef_NoLock(); + guard.unlock(); + Owner()->EnableWrite(this, /* needRef = */ false); } return true; } -uint32_t -FNET_Connection::GetQueueLen() -{ - std::lock_guard<std::mutex> guard(_ioc_lock); - return _queue.GetPacketCnt_NoLock() + _myQueue.GetPacketCnt_NoLock(); -} - - void FNET_Connection::Sync() { @@ -797,15 +741,9 @@ FNET_Connection::HandleWriteEvent() case FNET_CONNECTED: { std::unique_lock<std::mutex> guard(_ioc_lock); - if (_flags._writeLock) { - guard.unlock(); - EnableWriteEvent(false); - return true; - } - _flags._writeLock = true; _queue.FlushPackets_NoLock(&_myQueue); } - broken = !Write(false); + broken = !Write(); break; case FNET_CLOSING: case FNET_CLOSED: diff --git a/fnet/src/vespa/fnet/connection.h b/fnet/src/vespa/fnet/connection.h index 120c675dc70..44ad3fea97d 100644 --- a/fnet/src/vespa/fnet/connection.h +++ b/fnet/src/vespa/fnet/connection.h @@ -68,16 +68,16 @@ private: struct Flags { Flags() : _gotheader(false), - _writeLock(false), _inCallback(false), _callbackWait(false), - _discarding(false) + _discarding(false), + _framed(false) { } bool _gotheader; - bool _writeLock; bool _inCallback; bool _callbackWait; bool _discarding; + bool _framed; }; struct ResolveHandler : public vespalib::AsyncResolver::ResultHandler { FNET_Connection *connection; @@ -212,9 +212,8 @@ private: * for each one. * * @return false if socket is broken. - * @param read_packets count read packets here **/ - bool handle_packets(uint32_t &read_packets); + bool handle_packets(); /** * Read incoming data from socket. @@ -227,10 +226,8 @@ private: * Write outgoing data to socket. * * @return false if socket is broken. - * @param direct is this a direct write (called directly from - * postpacket, without waiting for a write event) **/ - bool Write(bool direct); + bool Write(); bool writePendingAfterConnect(); public: @@ -452,19 +449,6 @@ public: /** - * Obtain the number of packets located in the output queue for this - * connection. Note that this number is volatile and should only be - * used as an estimate. Also note that since a queue latching - * strategy is used, this method requires a mutex lock/unlock and is - * therefore not as cheap as may be expected. - * - * @return number of packets currently located in the output queue - * for this connection. - **/ - uint32_t GetQueueLen(); - - - /** * Sync with this connection. When this method is invoked it will * block until all packets currently posted on this connection is * encoded into the output buffer. Also, the amount of data in the diff --git a/fnet/src/vespa/fnet/databuffer.cpp b/fnet/src/vespa/fnet/databuffer.cpp index 3b2e7759c99..74a8bc4e12c 100644 --- a/fnet/src/vespa/fnet/databuffer.cpp +++ b/fnet/src/vespa/fnet/databuffer.cpp @@ -13,7 +13,6 @@ FNET_DataBuffer::FNET_DataBuffer(uint32_t len) if (len > 0) { Alloc::alloc(len).swap(_ownedBuf); - memset(_ownedBuf.get(), 0x55, len); _bufstart = static_cast<char *>(_ownedBuf.get()); assert(_bufstart != nullptr); } else { // len == 0 @@ -70,7 +69,6 @@ FNET_DataBuffer::Shrink(uint32_t newsize) } Alloc newBuf(Alloc::alloc(newsize)); - memset(newBuf.get(), 0x55, newsize); memcpy(newBuf.get(), _datapt, GetDataLen()); _ownedBuf.swap(newBuf); _bufstart = static_cast<char *>(_ownedBuf.get()); @@ -95,7 +93,6 @@ FNET_DataBuffer::Pack(uint32_t needbytes) bufsize *= 2; Alloc newBuf(Alloc::alloc(bufsize)); - memset(newBuf.get(), 0x55, bufsize); memcpy(newBuf.get(), _datapt, GetDataLen()); _ownedBuf.swap(newBuf); _bufstart = static_cast<char *>(_ownedBuf.get()); diff --git a/fnet/src/vespa/fnet/fnet.h b/fnet/src/vespa/fnet/fnet.h index 5a3a8b28942..c7570e025ec 100644 --- a/fnet/src/vespa/fnet/fnet.h +++ b/fnet/src/vespa/fnet/fnet.h @@ -32,8 +32,6 @@ class FNET_Packet; class FNET_PacketQueue; class FNET_Scheduler; class FNET_SimplePacketStreamer; -class FNET_StatCounters; -class FNET_Stats; class FNET_Task; class FNET_Transport; class FNET_TransportThread; @@ -52,7 +50,6 @@ class FNET_TransportThread; #include "task.h" #include "scheduler.h" #include "config.h" -#include "stats.h" #include "databuffer.h" #include "packet.h" #include "dummypacket.h" diff --git a/fnet/src/vespa/fnet/frt/invoker.cpp b/fnet/src/vespa/fnet/frt/invoker.cpp index f2dc331c707..b174c3a710e 100644 --- a/fnet/src/vespa/fnet/frt/invoker.cpp +++ b/fnet/src/vespa/fnet/frt/invoker.cpp @@ -64,18 +64,14 @@ FRT_RPCInvoker::FRT_RPCInvoker(FRT_Supervisor *supervisor, req->SetReturnHandler(this); } -bool FRT_RPCInvoker::IsInstant() { - return _method->IsInstant(); -} - -bool FRT_RPCInvoker::Invoke(bool freeChannel) +bool FRT_RPCInvoker::Invoke() { bool detached = false; _req->SetDetachedPT(&detached); (_method->GetHandler()->*_method->GetMethod())(_req); if (detached) return false; - HandleDone(freeChannel); + HandleDone(false); return true; } @@ -120,13 +116,6 @@ FRT_RPCInvoker::GetConnection() return _req->GetContext()._value.CHANNEL->GetConnection(); } - -void -FRT_RPCInvoker::Run(FastOS_ThreadInterface *, void *) -{ - Invoke(true); -} - //----------------------------------------------------------------------------- void FRT_HookInvoker::Invoke() diff --git a/fnet/src/vespa/fnet/frt/invoker.h b/fnet/src/vespa/fnet/frt/invoker.h index 15d74017200..64adf66688e 100644 --- a/fnet/src/vespa/fnet/frt/invoker.h +++ b/fnet/src/vespa/fnet/frt/invoker.h @@ -59,8 +59,7 @@ public: //----------------------------------------------------------------------------- -class FRT_RPCInvoker : public FastOS_Runnable, - public FRT_IReturnHandler +class FRT_RPCInvoker : public FRT_IReturnHandler { private: FRT_RPCRequest *_req; @@ -76,15 +75,13 @@ public: bool noReply); void ForceMethod(FRT_Method *method) { _method = method; } - bool IsInstant(); FRT_RPCRequest *GetRequest() { return _req; } void HandleDone(bool freeChannel); - bool Invoke(bool freeChannel); + bool Invoke(); void HandleReturn() override; FNET_Connection *GetConnection() override; - void Run(FastOS_ThreadInterface *, void *) override; }; //----------------------------------------------------------------------------- diff --git a/fnet/src/vespa/fnet/frt/reflection.cpp b/fnet/src/vespa/fnet/frt/reflection.cpp index 4285c512ebf..305294f4a3c 100644 --- a/fnet/src/vespa/fnet/frt/reflection.cpp +++ b/fnet/src/vespa/fnet/frt/reflection.cpp @@ -6,13 +6,12 @@ #include "supervisor.h" FRT_Method::FRT_Method(const char * name, const char * paramSpec, const char * returnSpec, - bool instant, FRT_METHOD_PT method, FRT_Invokable * handler) + FRT_METHOD_PT method, FRT_Invokable * handler) : _hashNext(nullptr), _listNext(nullptr), _name(strdup(name)), _paramSpec(strdup(paramSpec)), _returnSpec(strdup(returnSpec)), - _instant(instant), _method(method), _handler(handler), _docLen(0), @@ -171,7 +170,6 @@ void FRT_ReflectionBuilder::DefineMethod(const char *name, const char *paramSpec, const char *returnSpec, - bool instant, FRT_METHOD_PT method, FRT_Invokable *handler) { @@ -182,7 +180,6 @@ FRT_ReflectionBuilder::DefineMethod(const char *name, _method = new FRT_Method(name, paramSpec, returnSpec, - instant, method, handler); _lookup->AddMethod(_method); diff --git a/fnet/src/vespa/fnet/frt/reflection.h b/fnet/src/vespa/fnet/frt/reflection.h index 466e58413e9..5189cf81d0a 100644 --- a/fnet/src/vespa/fnet/frt/reflection.h +++ b/fnet/src/vespa/fnet/frt/reflection.h @@ -19,7 +19,6 @@ private: char *_name; // method name char *_paramSpec; // method parameter spec char *_returnSpec; // method return spec - bool _instant; // method is instant ? FRT_METHOD_PT _method; // method pointer FRT_Invokable *_handler; // method handler uint32_t _docLen; // method documentation length @@ -32,7 +31,6 @@ public: FRT_Method(const char *name, const char *paramSpec, const char *returnSpec, - bool instant, FRT_METHOD_PT method, FRT_Invokable *handler); @@ -42,7 +40,6 @@ public: const char *GetName() { return _name; } const char *GetParamSpec() { return _paramSpec; } const char *GetReturnSpec() { return _returnSpec; } - bool IsInstant() { return _instant; } FRT_METHOD_PT GetMethod() { return _method; } FRT_Invokable *GetHandler() { return _handler; } void SetDocumentation(FRT_Values *values); @@ -121,7 +118,6 @@ public: void DefineMethod(const char *name, const char *paramSpec, const char *returnSpec, - bool instant, FRT_METHOD_PT method, FRT_Invokable *handler); void MethodDesc(const char *desc); diff --git a/fnet/src/vespa/fnet/frt/rpcrequest.h b/fnet/src/vespa/fnet/frt/rpcrequest.h index a10653ce2f6..cc871e7ac0c 100644 --- a/fnet/src/vespa/fnet/frt/rpcrequest.h +++ b/fnet/src/vespa/fnet/frt/rpcrequest.h @@ -133,7 +133,7 @@ public: FNET_Packet *CreateReplyPacket(); void SetDetachedPT(bool *detachedPT) { _detachedPT = detachedPT; } - void Detach() { *_detachedPT = true; } + FRT_RPCRequest *Detach() { *_detachedPT = true; return this; } void SetAbortHandler(FRT_IAbortHandler *handler) { _abortHandler = handler; } void SetReturnHandler(FRT_IReturnHandler *handler) { _returnHandler = handler; } diff --git a/fnet/src/vespa/fnet/frt/supervisor.cpp b/fnet/src/vespa/fnet/frt/supervisor.cpp index 0d2a9b68f17..e509223c005 100644 --- a/fnet/src/vespa/fnet/frt/supervisor.cpp +++ b/fnet/src/vespa/fnet/frt/supervisor.cpp @@ -26,7 +26,8 @@ FRT_Supervisor::FRT_Supervisor(FNET_Transport *transport, } -FRT_Supervisor::FRT_Supervisor(uint32_t threadStackSize, +FRT_Supervisor::FRT_Supervisor(vespalib::CryptoEngine::SP crypto, + uint32_t threadStackSize, uint32_t maxThreads) : _transport(nullptr), _threadPool(nullptr), @@ -39,7 +40,7 @@ FRT_Supervisor::FRT_Supervisor(uint32_t threadStackSize, _connHooks(*this), _methodMismatchHook(nullptr) { - _transport = new FNET_Transport(); + _transport = new FNET_Transport(std::move(crypto), 1); assert(_transport != nullptr); if (threadStackSize > 0) { _threadPool = new FastOS_ThreadPool(threadStackSize, maxThreads); @@ -90,22 +91,6 @@ FRT_Supervisor::GetListenPort() const } -bool -FRT_Supervisor::RunInvocation(FRT_RPCInvoker *invoker) -{ - // XXX: implement queue with max length + max # threads - - if (_threadPool == nullptr || - _threadPool->NewThread(invoker) == nullptr) - { - invoker->GetRequest()->SetError(FRTE_RPC_OVERLOAD, - "Could not start thread"); - return false; - } - return true; -} - - FRT_Target * FRT_Supervisor::GetTarget(const char *spec) { @@ -178,7 +163,7 @@ FRT_Supervisor::SetMethodMismatchHook(FRT_METHOD_PT method, { delete _methodMismatchHook; _methodMismatchHook = new FRT_Method("frt.hook.methodMismatch", "*", "*", - true, method, handler); + method, handler); assert(_methodMismatchHook != nullptr); } @@ -283,25 +268,17 @@ FRT_Supervisor::HandlePacket(FNET_Packet *packet, FNET_Context context) && _methodMismatchHook != nullptr) { invoker->ForceMethod(_methodMismatchHook); - return (invoker->Invoke(false)) ? + return (invoker->Invoke()) ? FNET_FREE_CHANNEL : FNET_CLOSE_CHANNEL; } invoker->HandleDone(false); return FNET_FREE_CHANNEL; - } else if (invoker->IsInstant()) { - - return (invoker->Invoke(false)) ? - FNET_FREE_CHANNEL : FNET_CLOSE_CHANNEL; - } else { - if (!RunInvocation(invoker)) { - invoker->HandleDone(false); - return FNET_FREE_CHANNEL; - } - return FNET_CLOSE_CHANNEL; + return (invoker->Invoke()) ? + FNET_FREE_CHANNEL : FNET_CLOSE_CHANNEL; } } @@ -348,17 +325,17 @@ FRT_Supervisor::RPCHooks::InitRPC(FRT_Supervisor *supervisor) { FRT_ReflectionBuilder rb(supervisor); //--------------------------------------------------------------------------- - rb.DefineMethod("frt.rpc.ping", "", "", true, + rb.DefineMethod("frt.rpc.ping", "", "", FRT_METHOD(FRT_Supervisor::RPCHooks::RPC_Ping), this); rb.MethodDesc("Method that may be used to check if the server is online"); //--------------------------------------------------------------------------- - rb.DefineMethod("frt.rpc.echo", "*", "*", true, + rb.DefineMethod("frt.rpc.echo", "*", "*", FRT_METHOD(FRT_Supervisor::RPCHooks::RPC_Echo), this); rb.MethodDesc("Echo the parameters as return values"); rb.ParamDesc("params", "Any set of parameters"); rb.ReturnDesc("return", "The parameter values"); //--------------------------------------------------------------------------- - rb.DefineMethod("frt.rpc.getMethodList", "", "SSS", true, + rb.DefineMethod("frt.rpc.getMethodList", "", "SSS", FRT_METHOD(FRT_Supervisor::RPCHooks::RPC_GetMethodList), this); rb.MethodDesc("Obtain a list of all available methods"); @@ -366,7 +343,7 @@ FRT_Supervisor::RPCHooks::InitRPC(FRT_Supervisor *supervisor) rb.ReturnDesc("params", "Method parameter types"); rb.ReturnDesc("return", "Method return types"); //--------------------------------------------------------------------------- - rb.DefineMethod("frt.rpc.getMethodInfo", "s", "sssSSSS", true, + rb.DefineMethod("frt.rpc.getMethodInfo", "s", "sssSSSS", FRT_METHOD(FRT_Supervisor::RPCHooks::RPC_GetMethodInfo), this); rb.MethodDesc("Obtain detailed information about a single method"); @@ -447,7 +424,7 @@ FRT_Supervisor::ConnHooks::SetSessionInitHook(FRT_METHOD_PT method, { delete _sessionInitHook; _sessionInitHook = new FRT_Method("frt.hook.sessionInit", "", "", - true, method, handler); + method, handler); assert(_sessionInitHook != nullptr); } @@ -458,7 +435,7 @@ FRT_Supervisor::ConnHooks::SetSessionDownHook(FRT_METHOD_PT method, { delete _sessionDownHook; _sessionDownHook = new FRT_Method("frt.hook.sessionDown", "", "", - true, method, handler); + method, handler); assert(_sessionDownHook != nullptr); } @@ -469,7 +446,7 @@ FRT_Supervisor::ConnHooks::SetSessionFiniHook(FRT_METHOD_PT method, { delete _sessionFiniHook; _sessionFiniHook = new FRT_Method("frt.hook.sessionFini", "", "", - true, method, handler); + method, handler); assert(_sessionFiniHook != nullptr); } diff --git a/fnet/src/vespa/fnet/frt/supervisor.h b/fnet/src/vespa/fnet/frt/supervisor.h index e818e506186..dc7fb496239 100644 --- a/fnet/src/vespa/fnet/frt/supervisor.h +++ b/fnet/src/vespa/fnet/frt/supervisor.h @@ -9,6 +9,7 @@ #include <vespa/fnet/ipackethandler.h> #include <vespa/fnet/connection.h> #include <vespa/fnet/simplepacketstreamer.h> +#include <vespa/vespalib/net/crypto_engine.h> class FNET_Transport; class FRT_Target; @@ -82,10 +83,10 @@ private: FRT_Supervisor &operator=(const FRT_Supervisor &); public: - FRT_Supervisor(FNET_Transport *transport, - FastOS_ThreadPool *threadPool); - FRT_Supervisor(uint32_t threadStackSize = 65000, - uint32_t maxThreads = 0); + FRT_Supervisor(FNET_Transport *transport, FastOS_ThreadPool *threadPool); + FRT_Supervisor(vespalib::CryptoEngine::SP crypto, uint32_t threadStackSize = 65000, uint32_t maxThreads = 0); + FRT_Supervisor(uint32_t threadStackSize = 65000, uint32_t maxThreads = 0) + : FRT_Supervisor(vespalib::CryptoEngine::get_default(), threadStackSize, maxThreads) {} virtual ~FRT_Supervisor(); bool StandAlone() { return _standAlone; } @@ -98,8 +99,6 @@ public: bool Listen(int port); uint32_t GetListenPort() const; - bool RunInvocation(FRT_RPCInvoker *invoker); - FRT_Target *GetTarget(const char *spec); FRT_Target *Get2WayTarget(const char *spec, FNET_Context connContext = FNET_Context()); diff --git a/fnet/src/vespa/fnet/iocomponent.cpp b/fnet/src/vespa/fnet/iocomponent.cpp index 8276da57e2e..d4244cbf204 100644 --- a/fnet/src/vespa/fnet/iocomponent.cpp +++ b/fnet/src/vespa/fnet/iocomponent.cpp @@ -12,7 +12,6 @@ FNET_IOComponent::FNET_IOComponent(FNET_TransportThread *owner, : _ioc_next(nullptr), _ioc_prev(nullptr), _ioc_owner(owner), - _ioc_counters(_ioc_owner->GetStatCounters()), _ioc_socket_fd(socket_fd), _ioc_selector(nullptr), _ioc_spec(nullptr), @@ -20,9 +19,7 @@ FNET_IOComponent::FNET_IOComponent(FNET_TransportThread *owner, _ioc_timestamp(fastos::ClockSystem::now()), _ioc_lock(), _ioc_cond(), - _ioc_refcnt(1), - _ioc_directPacketWriteCnt(0), - _ioc_directDataWriteCnt(0) + _ioc_refcnt(1) { _ioc_spec = strdup(spec); assert(_ioc_spec != nullptr); diff --git a/fnet/src/vespa/fnet/iocomponent.h b/fnet/src/vespa/fnet/iocomponent.h index a48930428c7..901c3d1a5d0 100644 --- a/fnet/src/vespa/fnet/iocomponent.h +++ b/fnet/src/vespa/fnet/iocomponent.h @@ -2,14 +2,12 @@ #pragma once -#include "stats.h" #include <vespa/fastos/timestamp.h> #include <vespa/vespalib/net/selector.h> #include <mutex> #include <condition_variable> class FNET_TransportThread; -class FNET_StatCounters; class FNET_Config; /** @@ -45,7 +43,6 @@ protected: FNET_IOComponent *_ioc_next; // next in list FNET_IOComponent *_ioc_prev; // prev in list FNET_TransportThread *_ioc_owner; // owner(TransportThread) ref. - FNET_StatCounters *_ioc_counters; // stat counters int _ioc_socket_fd; // source of events. Selector *_ioc_selector; // attached event selector char *_ioc_spec; // connect/listen spec @@ -55,10 +52,6 @@ protected: std::condition_variable _ioc_cond; // synchronization uint32_t _ioc_refcnt; // reference counter - // direct write stats kept locally - uint32_t _ioc_directPacketWriteCnt; - uint32_t _ioc_directDataWriteCnt; - public: /** @@ -158,90 +151,6 @@ public: **/ void UpdateTimeOut(); - - /** - * Count packet read(s). This is a proxy method updating the stat - * counters associated with the owning transport object. - * - * @param cnt the number of packets read (default is 1). - **/ - void CountPacketRead(uint32_t cnt = 1) - { _ioc_counters->CountPacketRead(cnt); } - - - /** - * Count packet write(s). This is a proxy method updating the stat - * counters associated with the owning transport object. - * - * @param cnt the number of packets written (default is 1). - **/ - void CountPacketWrite(uint32_t cnt = 1) - { _ioc_counters->CountPacketWrite(cnt); } - - - /** - * Count direct packet write(s). This method will increase an - * internal counter. The shared stat counters may not be used - * because this method may be called by other threads than the - * transport thread. Note: The IO Component should be locked when - * this method is called. - * - * @param cnt the number of packets written (default is 1). - **/ - void CountDirectPacketWrite(uint32_t cnt = 1) - { _ioc_directPacketWriteCnt += cnt; } - - - /** - * Count read data. This is a proxy method updating the stat - * counters associated with the owning transport object. - * - * @param bytes the number of bytes read. - **/ - void CountDataRead(uint32_t bytes) - { _ioc_counters->CountDataRead(bytes); } - - - /** - * Count written data. This is a proxy method updating the stat - * counters associated with the owning transport object. - * - * @param bytes the number of bytes written. - **/ - void CountDataWrite(uint32_t bytes) - { _ioc_counters->CountDataWrite(bytes); } - - - /** - * Count direct written data. This method will increase an - * internal counter. The shared stat counters may not be used - * because this method may be called by other threads than the - * transport thread. Note: The IO Component should be locked when - * this method is called. - * - * @param bytes the number of bytes written. - **/ - void CountDirectDataWrite(uint32_t bytes) - { _ioc_directDataWriteCnt += bytes; } - - - /** - * Transfer the direct write stats held by this IO Component over to - * the stat counters associated with the owning transport object - * (and reset the local counters). Note: This method should only be - * called from the transport thread while having the lock on this IO - * Component. Note: This method is called from the transport loop - * and should generally not be called by application code. - **/ - void FlushDirectWriteStats() - { - _ioc_counters->CountPacketWrite(_ioc_directPacketWriteCnt); - _ioc_counters->CountDataWrite(_ioc_directDataWriteCnt); - _ioc_directPacketWriteCnt = 0; - _ioc_directDataWriteCnt = 0; - } - - /** * Attach an event selector to this component. Before deleting an * IOC, one must first call detach_selector to detach the diff --git a/fnet/src/vespa/fnet/stats.cpp b/fnet/src/vespa/fnet/stats.cpp deleted file mode 100644 index f156fe4afe7..00000000000 --- a/fnet/src/vespa/fnet/stats.cpp +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#include "stats.h" - -#include <vespa/log/log.h> -LOG_SETUP(".fnet"); - -FNET_StatCounters::FNET_StatCounters() - : _eventLoopCnt(0), - _eventCnt(0), - _ioEventCnt(0), - _packetReadCnt(0), - _packetWriteCnt(0), - _dataReadCnt(0), - _dataWriteCnt(0) -{ -} - - -FNET_StatCounters::~FNET_StatCounters() -{ -} - - -void -FNET_StatCounters::Clear() -{ - _eventLoopCnt = 0; - _eventCnt = 0; - _ioEventCnt = 0; - _packetReadCnt = 0; - _packetWriteCnt = 0; - _dataReadCnt = 0; - _dataWriteCnt = 0; -} - -//----------------------------------------------- - -FNET_Stats::FNET_Stats() - : _eventLoopRate(0), - _eventRate(0), - _ioEventRate(0), - _packetReadRate(0), - _packetWriteRate(0), - _dataReadRate(0), - _dataWriteRate(0) -{ -} - - -FNET_Stats::~FNET_Stats() -{ -} - - -void -FNET_Stats::Update(FNET_StatCounters *count, double secs) -{ - _eventLoopRate = (float)(FNET_STATS_OLD_FACTOR * _eventLoopRate - + (FNET_STATS_NEW_FACTOR - * ((double)count->_eventLoopCnt / secs))); - _eventRate = (float)(FNET_STATS_OLD_FACTOR * _eventRate - + (FNET_STATS_NEW_FACTOR - * ((double)count->_eventCnt / secs))); - _ioEventRate = (float)(FNET_STATS_OLD_FACTOR * _ioEventRate - + (FNET_STATS_NEW_FACTOR - * ((double)count->_ioEventCnt / secs))); - - _packetReadRate = (float)(FNET_STATS_OLD_FACTOR * _packetReadRate - + (FNET_STATS_NEW_FACTOR - * ((double)count->_packetReadCnt / secs))); - _packetWriteRate = (float)(FNET_STATS_OLD_FACTOR * _packetWriteRate - + (FNET_STATS_NEW_FACTOR - * ((double)count->_packetWriteCnt / secs))); - - _dataReadRate = (float)(FNET_STATS_OLD_FACTOR * _dataReadRate - + (FNET_STATS_NEW_FACTOR - * ((double)count->_dataReadCnt / (1000.0 * secs)))); - _dataWriteRate = (float)(FNET_STATS_OLD_FACTOR * _dataWriteRate - + (FNET_STATS_NEW_FACTOR - * ((double)count->_dataWriteCnt / (1000.0 * secs)))); -} - - -void -FNET_Stats::Log() -{ - LOG(info, "events[/s][loop/int/io][%.1f/%.1f/%.1f] " - "packets[/s][r/w][%.1f/%.1f] " - "data[kB/s][r/w][%.2f/%.2f]", - _eventLoopRate, - _eventRate, - _ioEventRate, - _packetReadRate, - _packetWriteRate, - _dataReadRate, - _dataWriteRate); -} diff --git a/fnet/src/vespa/fnet/stats.h b/fnet/src/vespa/fnet/stats.h deleted file mode 100644 index 76651393165..00000000000 --- a/fnet/src/vespa/fnet/stats.h +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#pragma once - -#include <cstdint> - -/** - * This class is used internally by @ref FNET_Transport objects to - * aggregate FNET statistics. The actual statistics are located in the - * @ref FNET_Stats class. - **/ -class FNET_StatCounters -{ -public: - uint32_t _eventLoopCnt; // # event loop iterations - uint32_t _eventCnt; // # internal events - uint32_t _ioEventCnt; // # IO events - uint32_t _packetReadCnt; // # packets read - uint32_t _packetWriteCnt; // # packets written - uint32_t _dataReadCnt; // # bytes read - uint32_t _dataWriteCnt; // # bytes written - - FNET_StatCounters(); - ~FNET_StatCounters(); - - void Clear(); - void CountEventLoop(uint32_t cnt) { _eventLoopCnt += cnt; } - void CountEvent(uint32_t cnt) { _eventCnt += cnt; } - void CountIOEvent(uint32_t cnt) { _ioEventCnt += cnt; } - void CountPacketRead(uint32_t cnt) { _packetReadCnt += cnt; } - void CountPacketWrite(uint32_t cnt) { _packetWriteCnt += cnt; } - void CountDataRead(uint32_t bytes) { _dataReadCnt += bytes; } - void CountDataWrite(uint32_t bytes) { _dataWriteCnt += bytes; } -}; - -//----------------------------------------------- - -#define FNET_STATS_OLD_FACTOR 0.5 -#define FNET_STATS_NEW_FACTOR 0.5 - -/** - * This class contains various FNET statistics. The statistics for a - * @ref FNET_Transport object may be obtained by invoking the GetStats - * method on it. - **/ -class FNET_Stats -{ -public: - /** - * Event loop iterations per second. - **/ - float _eventLoopRate; // loop iterations/s - - /** - * Internal events handled per second. - **/ - float _eventRate; // internal-events/s - - /** - * IO events handled per second. - **/ - float _ioEventRate; // IO-events/s - - /** - * Packets read per second. - **/ - float _packetReadRate; // packets/s - - /** - * Packets written per second. - **/ - float _packetWriteRate; // packets/s - - /** - * Data read per second (in kB). - **/ - float _dataReadRate; // kB/s - - /** - * Data written per second (in kB). - **/ - float _dataWriteRate; // kB/s - - FNET_Stats(); - ~FNET_Stats(); - - /** - * Update statistics. The new statistics are calculated based on - * both the current values and the input count structure indicating - * what has happened since the last statistics update. - * - * @param count what has happened since last statistics update. - * @param secs number of seconds since last statistics update. - **/ - void Update(FNET_StatCounters *count, double secs); - - /** - * Invoking this method will generate a log message of type - * FNET_INFO showing the values held by this object. - **/ - void Log(); -}; - diff --git a/fnet/src/vespa/fnet/transport.cpp b/fnet/src/vespa/fnet/transport.cpp index 0a484f9caaa..dfeb8d03436 100644 --- a/fnet/src/vespa/fnet/transport.cpp +++ b/fnet/src/vespa/fnet/transport.cpp @@ -112,14 +112,6 @@ FNET_Transport::SetMaxOutputBufferSize(uint32_t bytes) } void -FNET_Transport::SetDirectWrite(bool directWrite) -{ - for (const auto &thread: _threads) { - thread->SetDirectWrite(directWrite); - } -} - -void FNET_Transport::SetTCPNoDelay(bool noDelay) { for (const auto &thread: _threads) { @@ -128,14 +120,6 @@ FNET_Transport::SetTCPNoDelay(bool noDelay) } void -FNET_Transport::SetLogStats(bool logStats) -{ - for (const auto &thread: _threads) { - thread->SetLogStats(logStats); - } -} - -void FNET_Transport::sync() { for (const auto &thread: _threads) { diff --git a/fnet/src/vespa/fnet/transport.h b/fnet/src/vespa/fnet/transport.h index a9c9eee2296..15e69bd66a6 100644 --- a/fnet/src/vespa/fnet/transport.h +++ b/fnet/src/vespa/fnet/transport.h @@ -187,14 +187,6 @@ public: void SetMaxOutputBufferSize(uint32_t bytes); /** - * Enable or disable the direct write optimization. This is - * enabled by default and favors low latency above throughput. - * - * @param directWrite enable direct write? - **/ - void SetDirectWrite(bool directWrite); - - /** * Enable or disable use of the TCP_NODELAY flag with sockets * created by this transport object. * @@ -203,14 +195,6 @@ public: void SetTCPNoDelay(bool noDelay); /** - * Enable or disable logging of FNET statistics. This feature is - * disabled by default. - * - * @param logStats true if stats should be logged. - **/ - void SetLogStats(bool logStats); - - /** * Synchronize with all transport threads. This method will block * until all events posted before this method was invoked has been * processed. If a transport thread has been shut down (or is in diff --git a/fnet/src/vespa/fnet/transport_thread.cpp b/fnet/src/vespa/fnet/transport_thread.cpp index 443c90f1af4..2c0d00b22f3 100644 --- a/fnet/src/vespa/fnet/transport_thread.cpp +++ b/fnet/src/vespa/fnet/transport_thread.cpp @@ -31,15 +31,6 @@ struct Sync : public FNET_IExecutable } // namespace<unnamed> -#ifndef IAM_DOXYGEN -void -FNET_TransportThread::StatsTask::PerformTask() -{ - _transport->UpdateStats(); - Schedule(5.0); -} -#endif - void FNET_TransportThread::AddComponent(FNET_IOComponent *comp) { @@ -161,27 +152,31 @@ FNET_TransportThread::DiscardEvent(FNET_ControlPacket *cpacket, void -FNET_TransportThread::UpdateStats() +FNET_TransportThread::handle_add_cmd(FNET_IOComponent *ioc) { - _now.SetNow(); // trade some overhead for better stats - double ms = _now.MilliSecs() - _statTime.MilliSecs(); - _statTime = _now; - for (FNET_IOComponent *comp = _componentsHead; - comp != nullptr; comp = comp->_ioc_next) - { - auto guard(comp->getGuard()); - comp->FlushDirectWriteStats(); - } - { - std::lock_guard<std::mutex> guard(_lock); - _stats.Update(&_counters, ms / 1000.0); + if (ioc->handle_add_event()) { + AddComponent(ioc); + ioc->_flags._ioc_added = true; + ioc->attach_selector(_selector); + } else { + ioc->Close(); + AddDeleteComponent(ioc); } - _counters.Clear(); +} + - if (_config._logStats) - _stats.Log(); +void +FNET_TransportThread::handle_close_cmd(FNET_IOComponent *ioc) +{ + if (ioc->_flags._ioc_added) { + RemoveComponent(ioc); + ioc->SubRef(); + } + ioc->Close(); + AddDeleteComponent(ioc); } + extern "C" { static void pipehandler(int) @@ -209,10 +204,6 @@ FNET_TransportThread::FNET_TransportThread(FNET_Transport &owner_in) _startTime(), _now(), _scheduler(&_now), - _counters(), - _stats(), - _statsTask(&_scheduler, this), - _statTime(), _config(), _componentsHead(nullptr), _timeOutHead(nullptr), @@ -430,8 +421,6 @@ FNET_TransportThread::InitEventLoop() } _now.SetNow(); _startTime = _now; - _statTime = _now; - _statsTask.Schedule(5.0); return true; } @@ -441,7 +430,7 @@ FNET_TransportThread::handle_wakeup() { { std::lock_guard<std::mutex> guard(_lock); - CountEvent(_queue.FlushPackets_NoLock(&_myQueue)); + _queue.FlushPackets_NoLock(&_myQueue); } FNET_Context context; @@ -460,14 +449,7 @@ FNET_TransportThread::handle_wakeup() switch (packet->GetCommand()) { case FNET_ControlPacket::FNET_CMD_IOC_ADD: - if (context._value.IOC->handle_add_event()) { - AddComponent(context._value.IOC); - context._value.IOC->_flags._ioc_added = true; - context._value.IOC->attach_selector(_selector); - } else { - context._value.IOC->Close(); - AddDeleteComponent(context._value.IOC); - } + handle_add_cmd(context._value.IOC); break; case FNET_ControlPacket::FNET_CMD_IOC_ENABLE_READ: context._value.IOC->EnableReadEvent(true); @@ -479,19 +461,18 @@ FNET_TransportThread::handle_wakeup() break; case FNET_ControlPacket::FNET_CMD_IOC_ENABLE_WRITE: context._value.IOC->EnableWriteEvent(true); - context._value.IOC->SubRef(); + if (context._value.IOC->HandleWriteEvent()) { + context._value.IOC->SubRef(); + } else { + handle_close_cmd(context._value.IOC); + } break; case FNET_ControlPacket::FNET_CMD_IOC_DISABLE_WRITE: context._value.IOC->EnableWriteEvent(false); context._value.IOC->SubRef(); break; case FNET_ControlPacket::FNET_CMD_IOC_CLOSE: - if (context._value.IOC->_flags._ioc_added) { - RemoveComponent(context._value.IOC); - context._value.IOC->SubRef(); - } - context._value.IOC->Close(); - AddDeleteComponent(context._value.IOC); + handle_close_cmd(context._value.IOC); break; } } @@ -540,7 +521,6 @@ FNET_TransportThread::EventLoopIteration() // obtain I/O events _selector.poll(msTimeout); - CountEventLoop(); // sample current time (performed once per event loop iteration) _now.SetNow(); @@ -554,7 +534,6 @@ FNET_TransportThread::EventLoopIteration() #endif // handle wakeup and io-events - CountIOEvent(_selector.num_events()); _selector.dispatch(*this); // handle IOC time-outs @@ -585,9 +564,6 @@ FNET_TransportThread::EventLoopIteration() if (_finished) return false; - // unschedule statistics task - _statsTask.Kill(); - // flush event queue { std::lock_guard<std::mutex> guard(_lock); diff --git a/fnet/src/vespa/fnet/transport_thread.h b/fnet/src/vespa/fnet/transport_thread.h index 8af0642eebe..408d20619d2 100644 --- a/fnet/src/vespa/fnet/transport_thread.h +++ b/fnet/src/vespa/fnet/transport_thread.h @@ -6,7 +6,6 @@ #include "config.h" #include "task.h" #include "packetqueue.h" -#include "stats.h" #include <vespa/fastos/thread.h> #include <vespa/fastos/time.h> #include <vespa/vespalib/net/socket_handle.h> @@ -31,31 +30,11 @@ class FNET_TransportThread : public FastOS_Runnable public: using Selector = vespalib::Selector<FNET_IOComponent>; -#ifndef IAM_DOXYGEN - class StatsTask : public FNET_Task - { - private: - FNET_TransportThread *_transport; - StatsTask(const StatsTask &); - StatsTask &operator=(const StatsTask &); - public: - StatsTask(FNET_Scheduler *scheduler, - FNET_TransportThread *transport) : FNET_Task(scheduler), - _transport(transport) {} - void PerformTask() override; - }; - friend class FNET_TransportThread::StatsTask; -#endif // DOXYGEN - private: FNET_Transport &_owner; // owning transport layer FastOS_Time _startTime; // when event loop started FastOS_Time _now; // current time sampler FNET_Scheduler _scheduler; // transport thread scheduler - FNET_StatCounters _counters; // stat counters - FNET_Stats _stats; // current stats - StatsTask _statsTask; // stats task - FastOS_Time _statTime; // last stat update FNET_Config _config; // FNET configuration [static] FNET_IOComponent *_componentsHead; // I/O component list head FNET_IOComponent *_timeOutHead; // first IOC in list to time out @@ -156,49 +135,6 @@ private: /** - * Update internal FNET statistics. This method is called regularly - * by the statistics update task. - **/ - void UpdateStats(); - - - /** - * Obtain a reference to the stat counters used by this transport - * object. - * - * @return stat counters for this transport object. - **/ - FNET_StatCounters *GetStatCounters() { return &_counters; } - - - /** - * Count event loop iteration(s). - * - * @param cnt event loop iterations (default is 1). - **/ - void CountEventLoop(uint32_t cnt = 1) - { _counters.CountEventLoop(cnt); } - - - /** - * Count internal event(s). - * - * @param cnt number of internal events. - **/ - void CountEvent(uint32_t cnt) - { _counters.CountEvent(cnt); } - - - /** - * Count IO events. - * - * @param cnt number of IO events. - **/ - void CountIOEvent(uint32_t cnt) - { _counters.CountIOEvent(cnt); } - - - /** * Obtain a reference to the object holding the configuration for * this transport object. * @@ -207,6 +143,9 @@ private: FNET_Config *GetConfig() { return &_config; } + void handle_add_cmd(FNET_IOComponent *ioc); + void handle_close_cmd(FNET_IOComponent *ioc); + public: /** * Construct a transport object. To activate your newly created @@ -345,18 +284,6 @@ public: void SetMaxOutputBufferSize(uint32_t bytes) { _config._maxOutputBufferSize = bytes; } - - /** - * Enable or disable the direct write optimization. This is - * enabled by default and favors low latency above throughput. - * - * @param directWrite enable direct write? - **/ - void SetDirectWrite(bool directWrite) { - _config._directWrite = directWrite; - } - - /** * Enable or disable use of the TCP_NODELAY flag with sockets * created by this transport object. @@ -367,15 +294,6 @@ public: /** - * Enable or disable logging of FNET statistics. This feature is - * disabled by default. - * - * @param logStats true if stats should be logged. - **/ - void SetLogStats(bool logStats) { _config._logStats = logStats; } - - - /** * Add an I/O component to the working set of this transport * object. Note that the actual work is performed by the transport * thread. This method simply posts an event on the transport thread diff --git a/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/athenz/AthenzPrincipalFilterTest.java b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/athenz/AthenzPrincipalFilterTest.java index be5ab9c1d77..fdab450b435 100644 --- a/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/athenz/AthenzPrincipalFilterTest.java +++ b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/athenz/AthenzPrincipalFilterTest.java @@ -10,9 +10,9 @@ import com.yahoo.vespa.athenz.api.AthenzIdentity; import com.yahoo.vespa.athenz.api.AthenzPrincipal; import com.yahoo.vespa.athenz.api.AthenzUser; import com.yahoo.vespa.athenz.api.NToken; -import com.yahoo.vespa.athenz.tls.KeyAlgorithm; -import com.yahoo.vespa.athenz.tls.KeyUtils; -import com.yahoo.vespa.athenz.tls.X509CertificateBuilder; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.X509CertificateBuilder; import com.yahoo.vespa.athenz.utils.ntoken.NTokenValidator; import org.junit.Before; import org.junit.Test; @@ -22,6 +22,7 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.io.UncheckedIOException; +import java.math.BigInteger; import java.security.KeyPair; import java.security.cert.X509Certificate; import java.time.Duration; @@ -30,7 +31,8 @@ import java.util.Objects; import java.util.Set; import static com.yahoo.jdisc.Response.Status.UNAUTHORIZED; -import static com.yahoo.vespa.athenz.tls.SignatureAlgorithm.SHA256_WITH_RSA; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_ECDSA; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_RSA; import static java.util.Collections.emptyList; import static java.util.Collections.singleton; import static java.util.Collections.singletonList; @@ -189,11 +191,11 @@ public class AthenzPrincipalFilterTest { } private static X509Certificate createSelfSignedCertificate(AthenzIdentity identity) { - KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA, 512); + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); X500Principal x500Name = new X500Principal("CN="+ identity.getFullName()); Instant now = Instant.now(); return X509CertificateBuilder - .fromKeypair(keyPair, x500Name, now, now.plus(Duration.ofDays(30)), SHA256_WITH_RSA, 1) + .fromKeypair(keyPair, x500Name, now, now.plus(Duration.ofDays(30)), SHA256_WITH_ECDSA, BigInteger.ONE) .build(); } diff --git a/jdisc_http_service/pom.xml b/jdisc_http_service/pom.xml index f41994c4916..879036db355 100644 --- a/jdisc_http_service/pom.xml +++ b/jdisc_http_service/pom.xml @@ -16,6 +16,7 @@ <packaging>container-plugin</packaging> <name>${project.artifactId}</name> <dependencies> + <!-- PROVIDED SCOPE --> <dependency> <groupId>org.bouncycastle</groupId> <artifactId>bcpkix-jdk15on</artifactId> @@ -33,6 +34,56 @@ <classifier>no_aop</classifier> </dependency> <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>jdisc_jetty</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>config-lib</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>defaults</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>jdisc_core</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>annotations</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>component</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>container-accesslogging</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>vespajlib</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + + <!-- TEST SCOPE --> + <dependency> <groupId>org.apache.commons</groupId> <artifactId>commons-lang3</artifactId> <scope>test</scope> @@ -53,12 +104,6 @@ <scope>test</scope> </dependency> <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>jdisc_jetty</artifactId> - <version>${project.version}</version> - <scope>provided</scope> - </dependency> - <dependency> <groupId>org.mockito</groupId> <artifactId>mockito-core</artifactId> <scope>test</scope> @@ -80,58 +125,10 @@ </exclusions> </dependency> <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>config-lib</artifactId> - <version>${project.version}</version> - <scope>provided</scope> - </dependency> - <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>defaults</artifactId> - <version>${project.version}</version> - <scope>provided</scope> - </dependency> - <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>jdisc_core</artifactId> - <version>${project.version}</version> - <scope>provided</scope> - </dependency> - <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>annotations</artifactId> - <version>${project.version}</version> - <scope>provided</scope> + <groupId>org.springframework</groupId> + <artifactId>spring-test</artifactId> + <scope>test</scope> </dependency> - <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>component</artifactId> - <version>${project.version}</version> - <scope>provided</scope> - </dependency> - <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>container-accesslogging</artifactId> - <version>${project.version}</version> - <scope>provided</scope> - </dependency> - <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>vespajlib</artifactId> - <version>${project.version}</version> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.jetbrains</groupId> - <artifactId>annotations</artifactId> - <version>13.0</version> - <scope>provided</scope> - </dependency> - <dependency> - <groupId>org.springframework</groupId> - <artifactId>spring-test</artifactId> - <scope>test</scope> - </dependency> </dependencies> <build> <plugins> diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/ConnectorFactory.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/ConnectorFactory.java index 6e3b6a65c51..f9892759fbd 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/ConnectorFactory.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/ConnectorFactory.java @@ -2,18 +2,9 @@ package com.yahoo.jdisc.http.server.jetty; import com.google.inject.Inject; -import com.yahoo.config.InnerNode; import com.yahoo.jdisc.Metric; import com.yahoo.jdisc.http.ConnectorConfig; -import com.yahoo.jdisc.http.ConnectorConfig.Ssl; -import com.yahoo.jdisc.http.ConnectorConfig.Ssl.ExcludeCipherSuite; -import com.yahoo.jdisc.http.ConnectorConfig.Ssl.ExcludeProtocol; -import com.yahoo.jdisc.http.ConnectorConfig.Ssl.IncludeCipherSuite; -import com.yahoo.jdisc.http.ConnectorConfig.Ssl.IncludeProtocol; -import com.yahoo.jdisc.http.ssl.DefaultSslKeyStoreContext; -import com.yahoo.jdisc.http.ssl.DefaultSslTrustStoreContext; -import com.yahoo.jdisc.http.ssl.SslKeyStoreConfigurator; -import com.yahoo.jdisc.http.ssl.SslTrustStoreConfigurator; +import com.yahoo.jdisc.http.ssl.SslContextFactoryProvider; import org.eclipse.jetty.http.HttpVersion; import org.eclipse.jetty.server.HttpConfiguration; import org.eclipse.jetty.server.HttpConnectionFactory; @@ -24,9 +15,6 @@ import org.eclipse.jetty.server.SslConnectionFactory; import org.eclipse.jetty.util.ssl.SslContextFactory; import java.nio.channels.ServerSocketChannel; -import java.util.List; -import java.util.function.BiConsumer; -import java.util.function.Function; /** * @author Einar M R Rosenvinge @@ -35,16 +23,13 @@ import java.util.function.Function; public class ConnectorFactory { private final ConnectorConfig connectorConfig; - private final SslKeyStoreConfigurator sslKeyStoreConfigurator; - private final SslTrustStoreConfigurator sslTrustStoreConfigurator; + private final SslContextFactoryProvider sslContextFactoryProvider; @Inject public ConnectorFactory(ConnectorConfig connectorConfig, - SslKeyStoreConfigurator sslKeyStoreConfigurator, - SslTrustStoreConfigurator sslTrustStoreConfigurator) { + SslContextFactoryProvider sslContextFactoryProvider) { this.connectorConfig = connectorConfig; - this.sslKeyStoreConfigurator = sslKeyStoreConfigurator; - this.sslTrustStoreConfigurator = sslTrustStoreConfigurator; + this.sslContextFactoryProvider = sslContextFactoryProvider; } public ConnectorConfig getConnectorConfig() { @@ -65,25 +50,11 @@ public class ConnectorFactory { connector.setName(connectorConfig.name()); connector.setAcceptQueueSize(connectorConfig.acceptQueueSize()); connector.setReuseAddress(connectorConfig.reuseAddress()); - double soLingerTimeSeconds = connectorConfig.soLingerTime(); - if (soLingerTimeSeconds == -1) { - setSoLingerTime(connector, -1); - } else { - setSoLingerTime(connector, (int)(soLingerTimeSeconds * 1000.0)); - } connector.setIdleTimeout((long)(connectorConfig.idleTimeout() * 1000.0)); connector.setStopTimeout((long)(connectorConfig.stopTimeout() * 1000.0)); return connector; } - @SuppressWarnings("deprecation") - private static void setSoLingerTime(ServerConnector connector, int milliseconds) { - // TODO: Don't use deprecated methods. Deprecate soLingerTime from connector config - // Jetty says: "don't use as socket close linger time has undefined behavior for non-blocking sockets" - // Jetty implementation is now a noop: https://github.com/eclipse/jetty.project/issues/2468, http://mail.openjdk.java.net/pipermail/nio-dev/2018-June/005195.html - connector.setSoLingerTime(milliseconds); - } - private HttpConnectionFactory newHttpConnectionFactory() { HttpConfiguration httpConfig = new HttpConfiguration(); httpConfig.setSendDateHeader(true); @@ -100,48 +71,8 @@ public class ConnectorFactory { } private SslConnectionFactory newSslConnectionFactory() { - Ssl sslConfig = connectorConfig.ssl(); - - SslContextFactory factory = new JDiscSslContextFactory(); - - sslKeyStoreConfigurator.configure(new DefaultSslKeyStoreContext(factory)); - sslTrustStoreConfigurator.configure(new DefaultSslTrustStoreContext(factory)); - - switch (sslConfig.clientAuth()) { - case NEED_AUTH: - factory.setNeedClientAuth(true); - break; - case WANT_AUTH: - factory.setWantClientAuth(true); - break; - } - - if (!sslConfig.prng().isEmpty()) { - factory.setSecureRandomAlgorithm(sslConfig.prng()); - } - - setStringArrayParameter( - factory, sslConfig.excludeProtocol(), ExcludeProtocol::name, SslContextFactory::setExcludeProtocols); - setStringArrayParameter( - factory, sslConfig.includeProtocol(), IncludeProtocol::name, SslContextFactory::setIncludeProtocols); - setStringArrayParameter( - factory, sslConfig.excludeCipherSuite(), ExcludeCipherSuite::name, SslContextFactory::setExcludeCipherSuites); - setStringArrayParameter( - factory, sslConfig.includeCipherSuite(), IncludeCipherSuite::name, SslContextFactory::setIncludeCipherSuites); - - factory.setKeyManagerFactoryAlgorithm(sslConfig.sslKeyManagerFactoryAlgorithm()); - factory.setProtocol(sslConfig.protocol()); + SslContextFactory factory = sslContextFactoryProvider.getInstance(connectorConfig.name(), connectorConfig.listenPort()); return new SslConnectionFactory(factory, HttpVersion.HTTP_1_1.asString()); } - private static <T extends InnerNode> void setStringArrayParameter(SslContextFactory sslContextFactory, - List<T> configValues, - Function<T, String> nameProperty, - BiConsumer<SslContextFactory, String[]> setter) { - if (!configValues.isEmpty()) { - String[] nameArray = configValues.stream().map(nameProperty).toArray(String[]::new); - setter.accept(sslContextFactory, nameArray); - } - } - } diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java index 3a121e8b1ed..1e92fbef967 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java @@ -41,7 +41,7 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G } private static final String[] HTTP_RESPONSE_GROUPS = { Metrics.RESPONSES_1XX, Metrics.RESPONSES_2XX, Metrics.RESPONSES_3XX, - Metrics.RESPONSES_4XX, Metrics.RESPONSES_5XX }; + Metrics.RESPONSES_4XX, Metrics.RESPONSES_5XX, Metrics.RESPONSES_401, Metrics.RESPONSES_403}; private final AtomicLong inFlight = new AtomicLong(); private final LongAdder statistics[][]; @@ -112,6 +112,9 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G if (group >= 0) { HttpMethod method = getMethod(request); statistics[method.ordinal()][group].increment(); + if (group == 5 || group == 6) { // if 401/403, also increment 4xx + statistics[method.ordinal()][3].increment(); + } } long live = inFlight.decrementAndGet(); @@ -127,15 +130,19 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G } private int groupIndex(Request request) { - if (request.isHandled()) { - int index = (request.getResponse().getStatus() / 100) - 1; // 1xx = 0, 2xx = 1 etc. - if (index < 0 || index > statistics.length) { - return -1; - } else { - return index; - } + int index = request.getResponse().getStatus(); + if (index == 401) { + return 5; + } + if (index == 403) { + return 6; + } + + index = index / 100 - 1; // 1xx = 0, 2xx = 1 etc. + if (index < 0 || index >= statistics[0].length) { + return -1; } else { - return 3; // 4xx + return index; } } @@ -203,4 +210,10 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G } return shutdownCb; } + + @Override + public boolean isShutdown() { + FutureCallback futureCallback = shutdown.get(); + return futureCallback != null && futureCallback.isDone(); + } } diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java index 70d266fdfa5..8074af7f64f 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java @@ -108,6 +108,8 @@ public class JettyHttpServer extends AbstractServerProvider { String RESPONSES_3XX = "http.status.3xx"; String RESPONSES_4XX = "http.status.4xx"; String RESPONSES_5XX = "http.status.5xx"; + String RESPONSES_401 = "http.status.401"; + String RESPONSES_403 = "http.status.403"; String STARTED_MILLIS = "serverStartedMillis"; @Deprecated String MANHATTAN_STARTED_MILLIS = "proc.uptime"; diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/MetricReporter.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/MetricReporter.java index 53f330bbc7e..4b01a475842 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/MetricReporter.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/MetricReporter.java @@ -5,7 +5,6 @@ import com.yahoo.jdisc.Metric; import com.yahoo.jdisc.Metric.Context; import com.yahoo.jdisc.http.server.jetty.JettyHttpServer.Metrics; -import org.jetbrains.annotations.Nullable; import java.util.concurrent.atomic.AtomicBoolean; @@ -16,7 +15,7 @@ import java.util.concurrent.atomic.AtomicBoolean; */ public class MetricReporter { private final Metric metric; - private final @Nullable Context context; + private final Context context; private final long requestStartTime; @@ -24,7 +23,7 @@ public class MetricReporter { private final AtomicBoolean firstSetOfTimeToFirstByte = new AtomicBoolean(true); - public MetricReporter(Metric metric, @Nullable Context context, long requestStartTime) { + public MetricReporter(Metric metric, Context context, long requestStartTime) { this.metric = metric; this.context = context; this.requestStartTime = requestStartTime; diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/DefaultSslContextFactoryProvider.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/DefaultSslContextFactoryProvider.java new file mode 100644 index 00000000000..f2d5d42ee2c --- /dev/null +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/DefaultSslContextFactoryProvider.java @@ -0,0 +1,108 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jdisc.http.ssl; + +import com.yahoo.config.InnerNode; +import com.yahoo.jdisc.http.ConnectorConfig; +import com.yahoo.jdisc.http.ssl.pem.PemSslKeyStore; +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.X509CertificateUtils; +import org.eclipse.jetty.util.ssl.SslContextFactory; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.KeyStore; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.logging.Logger; + +/** + * JDisc's default implementation of {@link SslContextFactoryProvider} that uses the {@link ConnectorConfig} to construct a {@link SslContextFactory}. + * + * @author bjorncs + */ +public class DefaultSslContextFactoryProvider implements SslContextFactoryProvider { + + private final ConnectorConfig connectorConfig; + + public DefaultSslContextFactoryProvider(ConnectorConfig connectorConfig) { + validateConfig(connectorConfig.ssl()); + this.connectorConfig = connectorConfig; + } + + @Override + public SslContextFactory getInstance(String containerId, int port) { + ConnectorConfig.Ssl sslConfig = connectorConfig.ssl(); + if (!sslConfig.enabled()) throw new IllegalStateException(); + SslContextFactory factory = new JDiscSslContextFactory(); + + switch (sslConfig.clientAuth()) { + case NEED_AUTH: + factory.setNeedClientAuth(true); + break; + case WANT_AUTH: + factory.setWantClientAuth(true); + break; + } + + // NOTE: All ciphers matching ^TLS_RSA_.*$ are disabled by default in Jetty 9.4.12+ (https://github.com/eclipse/jetty.project/issues/2807) + // JDisc will allow these ciphers by default to support older clients (e.g. Java 8u60 and curl 7.29.0) + // Removing the exclusion will allow for the TLS_RSA variants that are not covered by other exclusions + String[] excludedCiphersWithoutTlsRsaExclusion = Arrays.stream(factory.getExcludeCipherSuites()) + .filter(cipher -> !cipher.equals("^TLS_RSA_.*$")) + .toArray(String[]::new); + factory.setExcludeCipherSuites(excludedCiphersWithoutTlsRsaExclusion); + + // Check if using new ssl syntax from services.xml + factory.setKeyStore(createKeystore(sslConfig)); + factory.setKeyStorePassword(""); + if (!sslConfig.caCertificateFile().isEmpty()) { + factory.setTrustStore(createTruststore(sslConfig)); + } + factory.setProtocol("TLS"); + return factory; + } + + private static void validateConfig(ConnectorConfig.Ssl config) { + if (!config.enabled()) return; + if (config.certificateFile().isEmpty()) { + throw new IllegalArgumentException("Missing certificate file."); + } + if (config.privateKeyFile().isEmpty()) { + throw new IllegalArgumentException("Missing private key file."); + } + + } + + private static KeyStore createTruststore(ConnectorConfig.Ssl sslConfig) { + List<X509Certificate> caCertificates = X509CertificateUtils.certificateListFromPem(readToString(sslConfig.caCertificateFile())); + KeyStoreBuilder truststoreBuilder = KeyStoreBuilder.withType(KeyStoreType.JKS); + for (int i = 0; i < caCertificates.size(); i++) { + truststoreBuilder.withCertificateEntry("entry-" + i, caCertificates.get(i)); + } + return truststoreBuilder.build(); + } + + private static KeyStore createKeystore(ConnectorConfig.Ssl sslConfig) { + PrivateKey privateKey = KeyUtils.fromPemEncodedPrivateKey(readToString(sslConfig.privateKeyFile())); + List<X509Certificate> certificates = X509CertificateUtils.certificateListFromPem(readToString(sslConfig.certificateFile())); + return KeyStoreBuilder.withType(KeyStoreType.JKS).withKeyEntry("default", privateKey, certificates).build(); + } + + private static String readToString(String filename) { + try { + return new String(Files.readAllBytes(Paths.get(filename))); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + +} diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/DefaultSslKeyStoreConfigurator.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/DefaultSslKeyStoreConfigurator.java deleted file mode 100644 index 1cf8997b465..00000000000 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/DefaultSslKeyStoreConfigurator.java +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.jdisc.http.ssl; - -import com.google.inject.Inject; -import com.yahoo.jdisc.http.ConnectorConfig; -import com.yahoo.jdisc.http.ssl.pem.PemSslKeyStore; - -import java.io.IOException; -import java.io.UncheckedIOException; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.security.KeyStore; -import java.util.logging.Logger; - -/** - * @author bjorncs - */ -public class DefaultSslKeyStoreConfigurator implements SslKeyStoreConfigurator { - - private static final Logger log = Logger.getLogger(DefaultSslKeyStoreConfigurator.class.getName()); - - @SuppressWarnings("deprecation") - private final com.yahoo.jdisc.http.SecretStore secretStore; - private final ConnectorConfig.Ssl config; - - @Inject - @SuppressWarnings("deprecation") - public DefaultSslKeyStoreConfigurator(ConnectorConfig config, com.yahoo.jdisc.http.SecretStore secretStore) { - validateConfig(config.ssl()); - this.secretStore = secretStore; - this.config = config.ssl(); - } - - private static void validateConfig(ConnectorConfig.Ssl config) { - if (!config.enabled()) return; - switch (config.keyStoreType()) { - case JKS: - validateJksConfig(config); - break; - case PEM: - validatePemConfig(config); - break; - } - } - - @Override - public void configure(SslKeyStoreContext context) { - if (!config.enabled()) return; - switch (config.keyStoreType()) { - case JKS: - context.updateKeyStore(config.keyStorePath(), "JKS", secretStore.getSecret(config.keyDbKey())); - break; - case PEM: - context.updateKeyStore(createPemKeyStore(config.pemKeyStore())); - break; - } - } - - private static void validateJksConfig(ConnectorConfig.Ssl ssl) { - if (!ssl.pemKeyStore().keyPath().isEmpty() || ! ssl.pemKeyStore().certificatePath().isEmpty()) { - throw new IllegalArgumentException("pemKeyStore attributes can not be set when keyStoreType is JKS."); - } - if (ssl.keyDbKey().isEmpty()) { - throw new IllegalArgumentException("Missing password for JKS keystore"); - } - } - - private static void validatePemConfig(ConnectorConfig.Ssl ssl) { - if (! ssl.keyStorePath().isEmpty()) { - throw new IllegalArgumentException("keyStorePath can not be set when keyStoreType is PEM"); - } - if (!ssl.keyDbKey().isEmpty()) { - // TODO Make an error once there are separate passwords for truststore and keystore - log.warning("Encrypted PEM key stores are not supported. Password is only applied to truststore"); - } - if (ssl.pemKeyStore().certificatePath().isEmpty()) { - throw new IllegalArgumentException("Missing certificate path."); - } - if (ssl.pemKeyStore().keyPath().isEmpty()) { - throw new IllegalArgumentException("Missing key path."); - } - } - - private static KeyStore createPemKeyStore(ConnectorConfig.Ssl.PemKeyStore pemKeyStore) { - try { - Path certificatePath = Paths.get(pemKeyStore.certificatePath()); - Path keyPath = Paths.get(pemKeyStore.keyPath()); - return new PemSslKeyStore(certificatePath, keyPath).loadJavaKeyStore(); - } catch (IOException e) { - throw new UncheckedIOException(e); - } catch (Exception e) { - throw new RuntimeException("Failed setting up key store for " + pemKeyStore.keyPath() + ", " + pemKeyStore.certificatePath(), e); - } - } - -} diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/DefaultSslKeyStoreContext.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/DefaultSslKeyStoreContext.java deleted file mode 100644 index 44a9c606576..00000000000 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/DefaultSslKeyStoreContext.java +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.jdisc.http.ssl; - -import org.eclipse.jetty.util.ssl.SslContextFactory; - -import java.security.KeyStore; -import java.util.function.Consumer; - -/** - * @author bjorncs - */ -public class DefaultSslKeyStoreContext implements SslKeyStoreContext { - - private final SslContextFactory sslContextFactory; - - public DefaultSslKeyStoreContext(SslContextFactory sslContextFactory) { - this.sslContextFactory = sslContextFactory; - } - - @Override - public void updateKeyStore(KeyStore keyStore) { - updateKeyStore(keyStore, null); - } - - @Override - public void updateKeyStore(KeyStore keyStore, String password) { - updateKeyStore(sslContextFactory -> { - sslContextFactory.setKeyStore(keyStore); - if (password != null) { - sslContextFactory.setKeyStorePassword(password); - } - }); - } - - @Override - public void updateKeyStore(String keyStorePath, String keyStoreType, String keyStorePassword) { - updateKeyStore(sslContextFactory -> { - sslContextFactory.setKeyStorePath(keyStorePath); - sslContextFactory.setKeyStoreType(keyStoreType); - sslContextFactory.setKeyStorePassword(keyStorePassword); - }); - } - - private void updateKeyStore(Consumer<SslContextFactory> reloader) { - try { - sslContextFactory.reload(reloader); - } catch (Exception e) { - throw new RuntimeException("Could not update keystore: " + e.getMessage(), e); - } - } -} diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/DefaultSslTrustStoreConfigurator.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/DefaultSslTrustStoreConfigurator.java deleted file mode 100644 index 5a8c399e6ba..00000000000 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/DefaultSslTrustStoreConfigurator.java +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.jdisc.http.ssl; - -import com.google.inject.Inject; -import com.yahoo.jdisc.http.ConnectorConfig; - -/** - * @author bjorncs - */ -public class DefaultSslTrustStoreConfigurator implements SslTrustStoreConfigurator { - - @SuppressWarnings("deprecation") - private final com.yahoo.jdisc.http.SecretStore secretStore; - private final ConnectorConfig.Ssl config; - - @Inject - @SuppressWarnings("deprecation") - public DefaultSslTrustStoreConfigurator(ConnectorConfig config, com.yahoo.jdisc.http.SecretStore secretStore) { - validateConfig(config.ssl()); - this.secretStore = secretStore; - this.config = config.ssl(); - } - - @Override - public void configure(SslTrustStoreContext context) { - if (!config.enabled()) return; - String keyDbPassword = config.keyDbKey(); - if (!config.trustStorePath().isEmpty()) { - String password = config.useTrustStorePassword() ? secretStore.getSecret(keyDbPassword) : null; - context.updateTrustStore(config.trustStorePath(), config.trustStoreType().toString(), password); - } - } - - private static void validateConfig(ConnectorConfig.Ssl config) { - if (!config.enabled()) return; - if (!config.trustStorePath().isEmpty() && config.useTrustStorePassword() && config.keyDbKey().isEmpty()) { - throw new IllegalArgumentException("Missing password for JKS truststore"); - } - } - -} diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/DefaultSslTrustStoreContext.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/DefaultSslTrustStoreContext.java deleted file mode 100644 index c2d91cca3ea..00000000000 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/DefaultSslTrustStoreContext.java +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.jdisc.http.ssl; - -import org.eclipse.jetty.util.ssl.SslContextFactory; - -import java.security.KeyStore; -import java.util.function.Consumer; - -/** - * @author bjorncs - */ -public class DefaultSslTrustStoreContext implements SslTrustStoreContext { - - private final SslContextFactory sslContextFactory; - - public DefaultSslTrustStoreContext(SslContextFactory sslContextFactory) { - this.sslContextFactory = sslContextFactory; - } - - @Override - public void updateTrustStore(KeyStore trustStore) { - updateTrustStore(trustStore, null); - } - - @Override - public void updateTrustStore(KeyStore trustStore, String password) { - updateTrustStore(sslContextFactory -> { - sslContextFactory.setTrustStore(trustStore); - if (password != null) { - sslContextFactory.setTrustStorePassword(password); - } - }); - } - - @Override - public void updateTrustStore(String trustStorePath, String trustStoreType, String trustStorePassword) { - updateTrustStore(sslContextFactory -> { - sslContextFactory.setTrustStorePath(trustStorePath); - sslContextFactory.setTrustStoreType(trustStoreType); - if (trustStorePassword != null) { - sslContextFactory.setTrustStorePassword(trustStorePassword); - } - }); - } - - private void updateTrustStore(Consumer<SslContextFactory> reloader) { - try { - sslContextFactory.reload(reloader); - } catch (Exception e) { - throw new RuntimeException("Could not update truststore: " + e.getMessage(), e); - } - } - -} diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscSslContextFactory.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/JDiscSslContextFactory.java index 81a6a0c8048..dcd9435334d 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscSslContextFactory.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/JDiscSslContextFactory.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.jdisc.http.server.jetty; +package com.yahoo.jdisc.http.ssl; import org.eclipse.jetty.util.resource.Resource; import org.eclipse.jetty.util.security.CertificateUtils; diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/LegacySslContextFactoryProvider.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/LegacySslContextFactoryProvider.java new file mode 100644 index 00000000000..5b090824e6a --- /dev/null +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/LegacySslContextFactoryProvider.java @@ -0,0 +1,163 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jdisc.http.ssl; + +import com.yahoo.config.InnerNode; +import com.yahoo.jdisc.http.ConnectorConfig; +import com.yahoo.jdisc.http.ssl.pem.PemSslKeyStore; +import org.eclipse.jetty.util.ssl.SslContextFactory; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.KeyStore; +import java.util.Arrays; +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.logging.Logger; + +/** + * A implementation of {@link SslContextFactoryProvider} to be injected into non-ssl connectors or connectors using legacy ssl config + * + * @author bjorncs + */ +// TODO Vespa 7: Remove legacy ssl config +public class LegacySslContextFactoryProvider implements SslContextFactoryProvider { + private static final Logger log = Logger.getLogger(LegacySslContextFactoryProvider.class.getName()); + + private final ConnectorConfig connectorConfig; + @SuppressWarnings("deprecation") + private final com.yahoo.jdisc.http.SecretStore secretStore; + + public LegacySslContextFactoryProvider(ConnectorConfig connectorConfig, + @SuppressWarnings("deprecation") com.yahoo.jdisc.http.SecretStore secretStore) { + validateConfig(connectorConfig.ssl()); + this.connectorConfig = connectorConfig; + this.secretStore = secretStore; + } + + @Override + public SslContextFactory getInstance(String containerId, int port) { + ConnectorConfig.Ssl sslConfig = connectorConfig.ssl(); + if (!sslConfig.enabled()) throw new IllegalStateException(); + SslContextFactory factory = new JDiscSslContextFactory(); + + switch (sslConfig.clientAuth()) { + case NEED_AUTH: + factory.setNeedClientAuth(true); + break; + case WANT_AUTH: + factory.setWantClientAuth(true); + break; + } + + // NOTE: All ciphers matching ^TLS_RSA_.*$ are disabled by default in Jetty 9.4.12+ (https://github.com/eclipse/jetty.project/issues/2807) + // JDisc will allow these ciphers by default to support older clients (e.g. Java 8u60 and curl 7.29.0) + // Removing the exclusion will allow for the TLS_RSA variants that are not covered by other exclusions + String[] excludedCiphersWithoutTlsRsaExclusion = Arrays.stream(factory.getExcludeCipherSuites()) + .filter(cipher -> !cipher.equals("^TLS_RSA_.*$")) + .toArray(String[]::new); + factory.setExcludeCipherSuites(excludedCiphersWithoutTlsRsaExclusion); + + switch (sslConfig.keyStoreType()) { + case JKS: + factory.setKeyStorePath(sslConfig.keyStorePath()); + factory.setKeyStoreType("JKS"); + factory.setKeyStorePassword(secretStore.getSecret(sslConfig.keyDbKey())); + break; + case PEM: + factory.setKeyStorePath(sslConfig.keyStorePath()); + factory.setKeyStore(createPemKeyStore(sslConfig.pemKeyStore())); + break; + } + + if (!sslConfig.trustStorePath().isEmpty()) { + factory.setTrustStorePath(sslConfig.trustStorePath()); + factory.setTrustStoreType(sslConfig.trustStoreType().toString()); + if (sslConfig.useTrustStorePassword()) { + factory.setTrustStorePassword(secretStore.getSecret(sslConfig.keyDbKey())); + } + } + + if (!sslConfig.prng().isEmpty()) { + factory.setSecureRandomAlgorithm(sslConfig.prng()); + } + + setStringArrayParameter( + factory, sslConfig.excludeProtocol(), ConnectorConfig.Ssl.ExcludeProtocol::name, SslContextFactory::setExcludeProtocols); + setStringArrayParameter( + factory, sslConfig.includeProtocol(), ConnectorConfig.Ssl.IncludeProtocol::name, SslContextFactory::setIncludeProtocols); + setStringArrayParameter( + factory, sslConfig.excludeCipherSuite(), ConnectorConfig.Ssl.ExcludeCipherSuite::name, SslContextFactory::setExcludeCipherSuites); + setStringArrayParameter( + factory, sslConfig.includeCipherSuite(), ConnectorConfig.Ssl.IncludeCipherSuite::name, SslContextFactory::setIncludeCipherSuites); + + factory.setKeyManagerFactoryAlgorithm(sslConfig.sslKeyManagerFactoryAlgorithm()); + factory.setProtocol(sslConfig.protocol()); + + return factory; + } + + private static void validateConfig(ConnectorConfig.Ssl config) { + if (!config.enabled()) return; + switch (config.keyStoreType()) { + case JKS: + validateJksConfig(config); + break; + case PEM: + validatePemConfig(config); + break; + } + if (!config.trustStorePath().isEmpty() && config.useTrustStorePassword() && config.keyDbKey().isEmpty()) { + throw new IllegalArgumentException("Missing password for JKS truststore"); + } + } + + private static void validateJksConfig(ConnectorConfig.Ssl ssl) { + if (!ssl.pemKeyStore().keyPath().isEmpty() || ! ssl.pemKeyStore().certificatePath().isEmpty()) { + throw new IllegalArgumentException("pemKeyStore attributes can not be set when keyStoreType is JKS."); + } + if (ssl.keyDbKey().isEmpty()) { + throw new IllegalArgumentException("Missing password for JKS keystore"); + } + } + + private static void validatePemConfig(ConnectorConfig.Ssl ssl) { + if (! ssl.keyStorePath().isEmpty()) { + throw new IllegalArgumentException("keyStorePath can not be set when keyStoreType is PEM"); + } + if (!ssl.keyDbKey().isEmpty()) { + log.warning("Encrypted PEM key stores are not supported. Password is only applied to truststore"); + } + if (ssl.pemKeyStore().certificatePath().isEmpty()) { + throw new IllegalArgumentException("Missing certificate path."); + } + if (ssl.pemKeyStore().keyPath().isEmpty()) { + throw new IllegalArgumentException("Missing key path."); + } + } + + private static KeyStore createPemKeyStore(ConnectorConfig.Ssl.PemKeyStore pemKeyStore) { + try { + Path certificatePath = Paths.get(pemKeyStore.certificatePath()); + Path keyPath = Paths.get(pemKeyStore.keyPath()); + return new PemSslKeyStore(certificatePath, keyPath).loadJavaKeyStore(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } catch (Exception e) { + throw new RuntimeException("Failed setting up key store for " + pemKeyStore.keyPath() + ", " + pemKeyStore.certificatePath(), e); + } + } + + private static <T extends InnerNode> void setStringArrayParameter(SslContextFactory sslContextFactory, + List<T> configValues, + Function<T, String> nameProperty, + BiConsumer<SslContextFactory, String[]> setter) { + if (!configValues.isEmpty()) { + String[] nameArray = configValues.stream().map(nameProperty).toArray(String[]::new); + setter.accept(sslContextFactory, nameArray); + } + } + +} diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/SslContextFactoryProvider.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/SslContextFactoryProvider.java new file mode 100644 index 00000000000..37916fd5734 --- /dev/null +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/SslContextFactoryProvider.java @@ -0,0 +1,20 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jdisc.http.ssl; + +import org.eclipse.jetty.util.ssl.SslContextFactory; + +/** + * A provider that is used to configure SSL connectors in JDisc + * + * @author bjorncs + */ +public interface SslContextFactoryProvider { + + /** + * This method is called once for each SSL connector. + * + * @return returns an instance of {@link SslContextFactory} for a given JDisc http server + */ + SslContextFactory getInstance(String containerId, int port); + +} diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/SslKeyStoreConfigurator.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/SslKeyStoreConfigurator.java deleted file mode 100644 index 619f4a636ed..00000000000 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/SslKeyStoreConfigurator.java +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.jdisc.http.ssl; - -/** - * An interface for an component that can configure an {@link SslKeyStoreContext}. The implementor can assume that - * the {@link SslKeyStoreContext} instance is thread-safe and be updated at any time - * during and after the call to{@link #configure(SslKeyStoreContext)}. - * Modifying the {@link SslKeyStoreContext} instance will trigger a hot reload of the keystore in JDisc. - * - * @author bjorncs - */ -public interface SslKeyStoreConfigurator { - void configure(SslKeyStoreContext context); -} diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/SslKeyStoreContext.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/SslKeyStoreContext.java deleted file mode 100644 index 2a25f6d78b5..00000000000 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/SslKeyStoreContext.java +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.jdisc.http.ssl; - -import java.security.KeyStore; - -/** - * An interface to update the keystore in JDisc. Any update will trigger a hot reload and new connections will - * immediately see the new certificate chain. - * - * @author bjorncs - */ -public interface SslKeyStoreContext { - void updateKeyStore(KeyStore keyStore); - void updateKeyStore(KeyStore keyStore, String password); - void updateKeyStore(String keyStorePath, String keyStoreType, String keyStorePassword); -} diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/SslTrustStoreConfigurator.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/SslTrustStoreConfigurator.java deleted file mode 100644 index de1119a5275..00000000000 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/SslTrustStoreConfigurator.java +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.jdisc.http.ssl; - -/** - * An interface for an component that can configure an {@link SslTrustStoreContext}. The implementor can assume that - * the {@link SslTrustStoreContext} instance is thread-safe and be updated at any time - * during and after the call to{@link #configure(SslTrustStoreContext)}. - * Modifying the {@link SslKeyStoreContext} instance will trigger a hot reload of the truststore in JDisc. - * - * @author bjorncs - */ -public interface SslTrustStoreConfigurator { - void configure(SslTrustStoreContext context); -} diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/SslTrustStoreContext.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/SslTrustStoreContext.java deleted file mode 100644 index fc8cf397b24..00000000000 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/SslTrustStoreContext.java +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.jdisc.http.ssl; - -import java.security.KeyStore; - -/** - * An interface to update the truststore in JDisc. Any update will trigger a hot reload and new connections will - * authenticated using the update truststore. - * - * @author bjorncs - */ -public interface SslTrustStoreContext { - void updateTrustStore(KeyStore trustStore); - void updateTrustStore(KeyStore trustStore, String password); - void updateTrustStore(String trustStorePath, String trustStoreType, String trustStorePassword); -} diff --git a/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.connector.def b/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.connector.def index f0673e240c7..157ffabdd63 100644 --- a/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.connector.def +++ b/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.connector.def @@ -25,8 +25,8 @@ acceptQueueSize int default=0 # Whether the server socket reuses addresses. reuseAddress bool default=true -# TODO Vespa 7: Remove soLingerTime - Jetty no longer support it -# DEPRECATED The linger time in seconds. Use -1.0 to disable. +# TODO Vespa 7: Remove soLingerTime - Jetty no longer support it. +# DEPRECATED No longer in use soLingerTime double default=-1.0 # The maximum idle time for a connection, which roughly translates to the Socket.setSoTimeout(int). @@ -44,6 +44,23 @@ tcpNoDelay bool default=true # Whether to enable SSL for this connector. ssl.enabled bool default=false +# File with private key in PEM format +ssl.privateKeyFile string default="" + +# File with certificate in PEM format +ssl.certificateFile string default="" + +# with trusted CA certificates in PEM format. Used to verify clients +ssl.caCertificateFile string default="" + +# Client authentication mode. See SSLEngine.getNeedClientAuth()/getWantClientAuth() for details. +ssl.clientAuth enum { DISABLED, WANT_AUTH, NEED_AUTH } default=DISABLED + + +######################################################################################### +# Config below is deprecated. Do not use +######################################################################################### + # The name of the key to the password to the key store if in the secret store, if JKS is used. # Must be empty with PEM # By default this is also used to look up the password to the trust store. @@ -89,11 +106,9 @@ ssl.sslKeyManagerFactoryAlgorithm string default="SunX509" # The SSL protocol passed to SSLContext.getInstance() ssl.protocol string default="TLS" -# Client authentication mode. See SSLEngine.getNeedClientAuth()/getWantClientAuth() for details. -ssl.clientAuth enum { DISABLED, WANT_AUTH, NEED_AUTH } default=DISABLED - # The SecureRandom implementation passed to SSLEngine.init() # Java have a default pseudo-random number generator (PRNG) for crypto operations. This default may have performance # issues on some platform (e.g. NativePRNG in Linux utilizes a global lock). Changing the generator to SHA1PRNG may # improve performance. Set value to empty string to use the default generator. ssl.prng string default="" + diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/JksKeyStore.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/JksKeyStore.java deleted file mode 100644 index 1c7a917c688..00000000000 --- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/JksKeyStore.java +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.jdisc.http; - -import java.io.InputStream; -import java.nio.file.Files; -import java.nio.file.Path; -import java.security.KeyStore; - -/** - * @author Tony Vaagenes - * @author bjorncs - */ -public class JksKeyStore { - - private static final String KEY_STORE_TYPE = "JKS"; - - private final Path keyStoreFile; - private final String keyStorePassword; - - public JksKeyStore(Path keyStoreFile) { - this(keyStoreFile, null); - } - - public JksKeyStore(Path keyStoreFile, String keyStorePassword) { - this.keyStoreFile = keyStoreFile; - this.keyStorePassword = keyStorePassword; - } - - public String getKeyStorePassword() { - return keyStorePassword; - } - - public KeyStore loadJavaKeyStore() throws Exception { - try(InputStream stream = Files.newInputStream(keyStoreFile)) { - KeyStore keystore = KeyStore.getInstance(KEY_STORE_TYPE); - keystore.load(stream, keyStorePassword != null ? keyStorePassword.toCharArray() : null); - return keystore; - } - } - -} diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/SslContextFactory.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/SslContextFactory.java deleted file mode 100644 index d86516df453..00000000000 --- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/SslContextFactory.java +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.jdisc.http; - -import javax.net.ssl.KeyManagerFactory; -import javax.net.ssl.SSLContext; -import javax.net.ssl.TrustManagerFactory; -import java.util.logging.Level; -import java.util.logging.Logger; - -/** - * @author Charles Kim - */ -public class SslContextFactory { - - private static final Logger log = Logger.getLogger(SslContextFactory.class.getName()); - private static final String DEFAULT_ALGORITHM = "SunX509"; - private static final String DEFAULT_PROTOCOL = "TLS"; - private final SSLContext sslContext; - - private SslContextFactory(SSLContext sslContext) { - this.sslContext = sslContext; - } - - public SSLContext getServerSSLContext() { - return this.sslContext; - } - - public static SslContextFactory newInstanceFromTrustStore(JksKeyStore trustStore) { - return newInstance(DEFAULT_ALGORITHM, DEFAULT_PROTOCOL, null, trustStore); - } - - public static SslContextFactory newInstance(JksKeyStore trustStore, JksKeyStore keyStore) { - return newInstance(DEFAULT_ALGORITHM, DEFAULT_PROTOCOL, keyStore, trustStore); - } - - public static SslContextFactory newInstance(String sslAlgorithm, String sslProtocol, - JksKeyStore keyStore, JksKeyStore trustStore) { - log.fine("Configuring SSLContext..."); - log.fine("Using " + sslAlgorithm + " algorithm."); - try { - SSLContext sslContext = SSLContext.getInstance(sslProtocol); - sslContext.init( - keyStore == null ? null : getKeyManagers(keyStore, sslAlgorithm), - trustStore == null ? null : getTrustManagers(trustStore, sslAlgorithm), - null); - return new SslContextFactory(sslContext); - } catch (Exception e) { - log.log(Level.SEVERE, "Got exception creating SSLContext.", e); - throw new RuntimeException(e); - } - } - - /** - * Used for the key store, which contains the SSL cert and private key. - */ - public static javax.net.ssl.KeyManager[] getKeyManagers(JksKeyStore keyStore, - String sslAlgorithm) throws Exception { - - KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(sslAlgorithm); - String keyStorePassword = keyStore.getKeyStorePassword(); - keyManagerFactory.init( - keyStore.loadJavaKeyStore(), - keyStorePassword != null ? keyStorePassword.toCharArray() : null); - log.fine("KeyManagerFactory initialized with keystore"); - return keyManagerFactory.getKeyManagers(); - } - - /** - * Used for the trust store, which contains certificates from other parties that you expect to communicate with, - * or from Certificate Authorities that you trust to identify other parties. - */ - public static javax.net.ssl.TrustManager[] getTrustManagers(JksKeyStore trustStore, - String sslAlgorithm) - throws Exception { - - TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(sslAlgorithm); - trustManagerFactory.init(trustStore.loadJavaKeyStore()); - log.fine("TrustManagerFactory initialized with truststore."); - return trustManagerFactory.getTrustManagers(); - } - -} diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/filter/SecurityRequestFilterChainTest.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/filter/SecurityRequestFilterChainTest.java new file mode 100644 index 00000000000..2e072f29039 --- /dev/null +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/filter/SecurityRequestFilterChainTest.java @@ -0,0 +1,145 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jdisc.http.filter; + +import com.yahoo.jdisc.AbstractResource; +import com.yahoo.jdisc.Response; +import com.yahoo.jdisc.handler.CompletionHandler; +import com.yahoo.jdisc.handler.ContentChannel; +import com.yahoo.jdisc.handler.ResponseDispatch; +import com.yahoo.jdisc.handler.ResponseHandler; +import com.yahoo.jdisc.http.HttpRequest; +import com.yahoo.jdisc.test.TestDriver; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +import static org.testng.Assert.assertEquals; + +/** + * @author bjorncs + */ +public class SecurityRequestFilterChainTest { + + + private static HttpRequest newRequest(URI uri, HttpRequest.Method method, HttpRequest.Version version) { + InetSocketAddress address = new InetSocketAddress("java.corp.yahoo.com", 69); + TestDriver driver = TestDriver.newSimpleApplicationInstanceWithoutOsgi(); + driver.activateContainer(driver.newContainerBuilder()); + HttpRequest request = HttpRequest.newServerRequest(driver, uri, method, version, address); + request.release(); + Assert.assertTrue(driver.close()); + return request; + } + + @Test + public void testFilterChainConstruction() { + SecurityRequestFilterChain chain = (SecurityRequestFilterChain)SecurityRequestFilterChain.newInstance(); + assertEquals(chain.getFilters().size(),0); + + List<SecurityRequestFilter> requestFilters = new ArrayList<SecurityRequestFilter>(); + chain = (SecurityRequestFilterChain)SecurityRequestFilterChain.newInstance(); + + chain = (SecurityRequestFilterChain)SecurityRequestFilterChain.newInstance(new RequestHeaderFilter("abc", "xyz"), + new RequestHeaderFilter("pqr", "def")); + + assertEquals(chain instanceof SecurityRequestFilterChain, true); + } + + + @Test + public void testFilterChainRun() { + RequestFilter chain = SecurityRequestFilterChain.newInstance(new RequestHeaderFilter("abc", "xyz"), + new RequestHeaderFilter("pqr", "def")); + + assertEquals(chain instanceof SecurityRequestFilterChain, true); + ResponseHandler handler = newResponseHandler(); + HttpRequest request = newRequest(URI.create("http://test/test"), HttpRequest.Method.GET, HttpRequest.Version.HTTP_1_1); + chain.filter(request, handler); + Assert.assertTrue(request.headers().contains("abc", "xyz")); + Assert.assertTrue(request.headers().contains("pqr", "def")); + } + + @Test + public void testFilterChainResponds() { + RequestFilter chain = SecurityRequestFilterChain.newInstance( + new MyFilter(), + new RequestHeaderFilter("abc", "xyz"), + new RequestHeaderFilter("pqr", "def")); + + assertEquals(chain instanceof SecurityRequestFilterChain, true); + ResponseHandler handler = newResponseHandler(); + HttpRequest request = newRequest(URI.create("http://test/test"), HttpRequest.Method.GET, HttpRequest.Version.HTTP_1_1); + chain.filter(request, handler); + Response response = getResponse(handler); + Assert.assertNotNull(response); + Assert.assertTrue(!request.headers().contains("abc", "xyz")); + Assert.assertTrue(!request.headers().contains("pqr", "def")); + } + + private class RequestHeaderFilter extends AbstractResource implements SecurityRequestFilter { + + private final String key; + private final String val; + + public RequestHeaderFilter(String key, String val) { + this.key = key; + this.val = val; + } + + @Override + public void filter(DiscFilterRequest request, ResponseHandler handler) { + request.setHeaders(key, val); + } + } + + private class MyFilter extends AbstractResource implements SecurityRequestFilter { + + @Override + public void filter(DiscFilterRequest request, ResponseHandler handler) { + ResponseDispatch.newInstance(Response.Status.FORBIDDEN).dispatch(handler); + } + } + + private static ResponseHandler newResponseHandler() { + return new NonWorkingResponseHandler(); + } + + private static Response getResponse(ResponseHandler handler) { + return ((NonWorkingResponseHandler) handler).getResponse(); + } + + private static class NonWorkingResponseHandler implements ResponseHandler { + + private Response response = null; + + @Override + public ContentChannel handleResponse(Response response) { + this.response = response; + return new NonWorkingContentChannel(); + } + + public Response getResponse() { + return response; + } + } + + private static class NonWorkingContentChannel implements ContentChannel { + + @Override + public void close(CompletionHandler handler) { + + } + + @Override + public void write(ByteBuffer buf, CompletionHandler handler) { + + } + + } + +} diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/filter/SecurityResponseFilterChainTest.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/filter/SecurityResponseFilterChainTest.java new file mode 100644 index 00000000000..b38ca240a78 --- /dev/null +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/filter/SecurityResponseFilterChainTest.java @@ -0,0 +1,75 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jdisc.http.filter; + +import com.yahoo.jdisc.AbstractResource; +import com.yahoo.jdisc.Response; +import com.yahoo.jdisc.http.HttpRequest; +import com.yahoo.jdisc.http.HttpResponse; +import com.yahoo.jdisc.test.TestDriver; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.net.InetSocketAddress; +import java.net.URI; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +/** + * @author bjorncs + */ +public class SecurityResponseFilterChainTest { + private static HttpRequest newRequest(URI uri, HttpRequest.Method method, HttpRequest.Version version) { + InetSocketAddress address = new InetSocketAddress("java.corp.yahoo.com", 69); + TestDriver driver = TestDriver.newSimpleApplicationInstanceWithoutOsgi(); + driver.activateContainer(driver.newContainerBuilder()); + HttpRequest request = HttpRequest.newServerRequest(driver, uri, method, version, address); + request.release(); + Assert.assertTrue(driver.close()); + return request; + } + + @Test + public void testFilterChainConstruction() { + SecurityResponseFilterChain chain = (SecurityResponseFilterChain)SecurityResponseFilterChain.newInstance(); + assertEquals(chain.getFilters().size(),0); + + chain = (SecurityResponseFilterChain)SecurityResponseFilterChain.newInstance(new ResponseHeaderFilter("abc", "xyz"), + new ResponseHeaderFilter("pqr", "def")); + + assertEquals(chain instanceof SecurityResponseFilterChain, true); + } + + @Test + public void testFilterChainRun() { + URI uri = URI.create("http://localhost:8080/echo"); + HttpRequest request = newRequest(uri, HttpRequest.Method.GET, HttpRequest.Version.HTTP_1_1); + Response response = HttpResponse.newInstance(Response.Status.OK); + + ResponseFilter chain = SecurityResponseFilterChain.newInstance(new ResponseHeaderFilter("abc", "xyz"), + new ResponseHeaderFilter("pqr", "def")); + chain.filter(response, null); + assertTrue(response.headers().contains("abc", "xyz")); + assertTrue(response.headers().contains("pqr", "def")); + } + + private class ResponseHeaderFilter extends AbstractResource implements SecurityResponseFilter { + + private final String key; + private final String val; + + public ResponseHeaderFilter(String key, String val) { + this.key = key; + this.val = val; + } + + @Override + public void filter(DiscFilterResponse response, RequestView request) { + response.setHeaders(key, val); + } + + } + + + +} diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/guiceModules/ConnectorFactoryRegistryModule.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/guiceModules/ConnectorFactoryRegistryModule.java index d1a78f33e8f..d204d633304 100644 --- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/guiceModules/ConnectorFactoryRegistryModule.java +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/guiceModules/ConnectorFactoryRegistryModule.java @@ -11,8 +11,7 @@ import com.yahoo.jdisc.http.ConnectorConfig.Builder; import com.yahoo.jdisc.http.server.jetty.ConnectorFactory; import com.yahoo.jdisc.http.server.jetty.TestDrivers; -import com.yahoo.jdisc.http.ssl.DefaultSslKeyStoreConfigurator; -import com.yahoo.jdisc.http.ssl.DefaultSslTrustStoreConfigurator; +import com.yahoo.jdisc.http.ssl.DefaultSslContextFactoryProvider; /** * Guice module for test ConnectorFactories @@ -48,19 +47,7 @@ public class ConnectorFactoryRegistryModule implements Module { private static class StaticKeyDbConnectorFactory extends ConnectorFactory { public StaticKeyDbConnectorFactory(ConnectorConfig connectorConfig) { - super(connectorConfig, - new DefaultSslKeyStoreConfigurator(connectorConfig, new MockSecretStore()), - new DefaultSslTrustStoreConfigurator(connectorConfig, new MockSecretStore())); - } - - } - - @SuppressWarnings("deprecation") - private static final class MockSecretStore implements com.yahoo.jdisc.http.SecretStore { - - @Override - public String getSecret(String key) { - return TestDrivers.KEY_STORE_PASSWORD; + super(connectorConfig, new DefaultSslContextFactoryProvider(connectorConfig)); } } diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/ConnectorFactoryTest.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/ConnectorFactoryTest.java index 083be36043e..08a38d5e13b 100644 --- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/ConnectorFactoryTest.java +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/ConnectorFactoryTest.java @@ -3,8 +3,7 @@ package com.yahoo.jdisc.http.server.jetty; import com.yahoo.jdisc.Metric; import com.yahoo.jdisc.http.ConnectorConfig; -import com.yahoo.jdisc.http.ssl.DefaultSslKeyStoreConfigurator; -import com.yahoo.jdisc.http.ssl.DefaultSslTrustStoreConfigurator; +import com.yahoo.jdisc.http.ssl.DefaultSslContextFactoryProvider; import org.eclipse.jetty.server.Request; import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.handler.AbstractHandler; @@ -106,10 +105,7 @@ public class ConnectorFactoryTest { } private static ConnectorFactory createConnectorFactory(ConnectorConfig config) { - ThrowingSecretStore secretStore = new ThrowingSecretStore(); - return new ConnectorFactory(config, - new DefaultSslKeyStoreConfigurator(config, secretStore), - new DefaultSslTrustStoreConfigurator(config, secretStore)); + return new ConnectorFactory(config, new DefaultSslContextFactoryProvider(config)); } private static class HelloWorldHandler extends AbstractHandler { @@ -138,14 +134,4 @@ public class ConnectorFactoryTest { private static class DummyContext implements Metric.Context { } - @SuppressWarnings("deprecation") - private static final class ThrowingSecretStore implements com.yahoo.jdisc.http.SecretStore { - - @Override - public String getSecret(String key) { - throw new UnsupportedOperationException("A secret store is not available"); - } - - } - } diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java index e3d70fb5bd6..3c23a2b0937 100644 --- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java @@ -65,6 +65,19 @@ public class HttpResponseStatisticsCollectorTest { } @Test + public void statistics_include_grouped_and_single_statuscodes() throws Exception { + testRequest(401, "GET"); + testRequest(404, "GET"); + testRequest(403, "GET"); + + Map<String, Map<String, Long>> stats = collector.takeStatisticsByMethod(); + assertThat(stats.get("GET").get(Metrics.RESPONSES_4XX), equalTo(3L)); + assertThat(stats.get("GET").get(Metrics.RESPONSES_401), equalTo(1L)); + assertThat(stats.get("GET").get(Metrics.RESPONSES_403), equalTo(1L)); + + } + + @Test public void retrieving_statistics_resets_the_counters() throws Exception { testRequest(200, "GET"); testRequest(200, "GET"); diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/TestDriver.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/TestDriver.java index 39b68fcf1f6..227b0b20f10 100644 --- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/TestDriver.java +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/TestDriver.java @@ -1,20 +1,16 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.jdisc.http.server.jetty; -import com.google.inject.Key; import com.google.inject.Module; import com.yahoo.jdisc.application.ContainerBuilder; import com.yahoo.jdisc.handler.RequestHandler; import com.yahoo.jdisc.http.ConnectorConfig; -import com.yahoo.jdisc.http.SslContextFactory; -import com.yahoo.jdisc.http.JksKeyStore; +import com.yahoo.security.SslContextBuilder; import javax.net.ssl.SSLContext; import java.io.IOException; import java.nio.file.Paths; -import static com.google.inject.name.Names.named; - /** * This class is based on the class by the same name in the jdisc_http_service module. * It provides functionality for setting up a jdisc container with an HTTP server and handlers. @@ -61,9 +57,7 @@ public class TestDriver { public SimpleHttpClient client() { return client; } - public SimpleHttpClient newClient() throws IOException { return newClient(false); } - - public SimpleHttpClient newClient(final boolean useCompression) throws IOException { + public SimpleHttpClient newClient(final boolean useCompression) { return new SimpleHttpClient(newSslContext(), server.getListenPort(), useCompression); } @@ -75,10 +69,10 @@ public class TestDriver { ConnectorConfig.Ssl sslConfig = builder.getInstance(ConnectorConfig.class).ssl(); if (!sslConfig.enabled()) return null; - JksKeyStore keyStore = new JksKeyStore( - Paths.get(sslConfig.keyStorePath()), - builder.getInstance(Key.get(String.class, named("keyStorePassword")))); - return SslContextFactory.newInstanceFromTrustStore(keyStore).getServerSSLContext(); + return new SslContextBuilder() + .withKeyStore(Paths.get(sslConfig.privateKeyFile()), Paths.get(sslConfig.certificateFile())) + .withTrustStore(Paths.get(sslConfig.caCertificateFile())) + .build(); } } diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/TestDrivers.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/TestDrivers.java index f4344545637..b7805328124 100644 --- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/TestDrivers.java +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/TestDrivers.java @@ -17,15 +17,13 @@ import com.yahoo.jdisc.http.server.FilterBindings; import java.io.IOException; -import static com.google.inject.name.Names.named; - /** * @author Simon Thoresen Hult */ public class TestDrivers { - private static final String KEY_STORE = "src/test/resources/ssl_keystore_test.jks"; - public static final String KEY_STORE_PASSWORD = "secret"; + private static final String PRIVATE_KEY_FILE = "src/test/resources/pem/test.key"; + private static final String CERTIFICATE_FILE = "src/test/resources/pem/test.crt"; public static TestDriver newConfiguredInstance(final RequestHandler requestHandler, final ServerConfig.Builder serverConfig, @@ -59,18 +57,10 @@ public class TestDrivers { new ConnectorConfig.Builder() .ssl(new ConnectorConfig.Ssl.Builder() .enabled(true) - .keyDbKey("dummy-key-for-StaticKeyDbConnectorFactory.getPasswordFromKeydb") - .keyStorePath(KEY_STORE) - .trustStorePath(KEY_STORE)), - Modules.combine(new AbstractModule() { - - @Override - protected void configure() { - bind(String.class).annotatedWith(named("keyStorePassword")) - .toInstance(KEY_STORE_PASSWORD); - } - }, Modules.combine(guiceModules)) - )); + .privateKeyFile(PRIVATE_KEY_FILE) + .certificateFile(CERTIFICATE_FILE) + .caCertificateFile(CERTIFICATE_FILE)), + Modules.combine(guiceModules))); } private static Module newConfigModule( diff --git a/jdisc_http_service/src/test/resources/ssl_keystore_test.jks b/jdisc_http_service/src/test/resources/ssl_keystore_test.jks Binary files differdeleted file mode 100644 index 6dbb19b9692..00000000000 --- a/jdisc_http_service/src/test/resources/ssl_keystore_test.jks +++ /dev/null diff --git a/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java b/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java index d0f5de54b4f..bc3d1edda7c 100644 --- a/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java +++ b/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java @@ -32,7 +32,7 @@ public final class MbusServer extends AbstractResource implements ServerProvider private final ResourceReference sessionReference; @Inject - public MbusServer(final CurrentContainer container, final ServerSession session) { + public MbusServer(CurrentContainer container, ServerSession session) { this.container = container; this.session = session; uri = URI.create("mbus://localhost/" + session.name()); @@ -60,7 +60,7 @@ public final class MbusServer extends AbstractResource implements ServerProvider } @Override - public void handleMessage(final Message msg) { + public void handleMessage(Message msg) { if (!running.get()) { dispatchErrorReply(msg, ErrorCode.SESSION_BUSY, "Session temporarily closed."); return; @@ -73,7 +73,7 @@ public final class MbusServer extends AbstractResource implements ServerProvider try { request = new MbusRequest(container, uri, msg); content = request.connect(new ServerResponseHandler(msg)); - } catch (final RuntimeException e) { + } catch (RuntimeException e) { dispatchErrorReply(msg, ErrorCode.APP_FATAL_ERROR, e.toString()); } finally { if (request != null) { @@ -89,8 +89,8 @@ public final class MbusServer extends AbstractResource implements ServerProvider return session.connectionSpec(); } - private void dispatchErrorReply(final Message msg, final int errCode, final String errMsg) { - final Reply reply = new EmptyReply(); + private void dispatchErrorReply(Message msg, int errCode, String errMsg) { + Reply reply = new EmptyReply(); reply.swapState(msg); reply.addError(new Error(errCode, errMsg)); session.sendReply(reply); @@ -100,20 +100,20 @@ public final class MbusServer extends AbstractResource implements ServerProvider final Message msg; - ServerResponseHandler(final Message msg) { + ServerResponseHandler(Message msg) { this.msg = msg; } @Override - public ContentChannel handleResponse(final Response response) { - final Reply reply; + public ContentChannel handleResponse(Response response) { + Reply reply; if (response instanceof MbusResponse) { reply = ((MbusResponse)response).getReply(); } else { reply = new EmptyReply(); reply.swapState(msg); } - final Error err = StatusCodes.toMbusError(response.getStatus()); + Error err = StatusCodes.toMbusError(response.getStatus()); if (err != null) { if (err.isFatal()) { if (!reply.hasFatalErrors()) { diff --git a/jrt/pom.xml b/jrt/pom.xml index cf3da2ab7ce..84578f9e04d 100644 --- a/jrt/pom.xml +++ b/jrt/pom.xml @@ -19,6 +19,11 @@ <scope>test</scope> </dependency> <dependency> + <groupId>org.bouncycastle</groupId> + <artifactId>bcpkix-jdk15on</artifactId> + <scope>test</scope> + </dependency> + <dependency> <groupId>com.yahoo.vespa</groupId> <artifactId>annotations</artifactId> <version>${project.version}</version> diff --git a/jrt/src/com/yahoo/jrt/CryptoEngine.java b/jrt/src/com/yahoo/jrt/CryptoEngine.java index 9852d5a88a6..2ef936ec7ed 100644 --- a/jrt/src/com/yahoo/jrt/CryptoEngine.java +++ b/jrt/src/com/yahoo/jrt/CryptoEngine.java @@ -2,7 +2,10 @@ package com.yahoo.jrt; +import com.yahoo.security.tls.TransportSecurityOptions; + import java.nio.channels.SocketChannel; +import java.nio.file.Paths; /** @@ -13,5 +16,12 @@ import java.nio.channels.SocketChannel; **/ public interface CryptoEngine { public CryptoSocket createCryptoSocket(SocketChannel channel, boolean isServer); - static public CryptoEngine createDefault() { return new NullCryptoEngine(); } + static public CryptoEngine createDefault() { // TODO Move this logic to a dedicated factory class + String tlsConfigParameter = System.getenv("VESPA_TLS_CONFIG_FILE"); + if (tlsConfigParameter != null && !tlsConfigParameter.isEmpty()) { + return new TlsCryptoEngine(TransportSecurityOptions.fromJsonFile(Paths.get(tlsConfigParameter))); + } else { + return new NullCryptoEngine(); + } + } } diff --git a/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java b/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java new file mode 100644 index 00000000000..b3daf5c296d --- /dev/null +++ b/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java @@ -0,0 +1,48 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jrt; + +import com.yahoo.security.SslContextBuilder; +import com.yahoo.security.X509CertificateUtils; +import com.yahoo.security.tls.TransportSecurityOptions; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.channels.SocketChannel; +import java.nio.file.Files; +import java.security.cert.X509Certificate; +import java.util.List; + +/** + * A {@link CryptoSocket} that creates {@link TlsCryptoSocket} instances. + * + * @author bjorncs + */ +public class TlsCryptoEngine implements CryptoEngine { + + private final SSLContext sslContext; + + public TlsCryptoEngine(SSLContext sslContext) { + this.sslContext = sslContext; + } + + public TlsCryptoEngine(TransportSecurityOptions options) { + this(createSslContext(options)); + } + + @Override + public TlsCryptoSocket createCryptoSocket(SocketChannel channel, boolean isServer) { + SSLEngine sslEngine = sslContext.createSSLEngine(); + sslEngine.setNeedClientAuth(true); + sslEngine.setUseClientMode(!isServer); + return new TlsCryptoSocket(channel, sslEngine); + } + + private static SSLContext createSslContext(TransportSecurityOptions options) { + return new SslContextBuilder() + .withTrustStore(options.getCaCertificatesFile()) + .withKeyStore(options.getPrivateKeyFile(), options.getCertificatesFile()) + .build(); + } +} diff --git a/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java b/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java new file mode 100644 index 00000000000..3db54811f9e --- /dev/null +++ b/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java @@ -0,0 +1,253 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jrt; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSession; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.SocketChannel; +import java.util.logging.Logger; + +import static javax.net.ssl.SSLEngineResult.*; + +/** + * A {@link CryptoSocket} using TLS ({@link SSLEngine}) + * + * @author bjorncs + */ +public class TlsCryptoSocket implements CryptoSocket { + + private static final ByteBuffer NULL_BUFFER = ByteBuffer.allocate(0); + + private static final Logger log = Logger.getLogger(TlsCryptoSocket.class.getName()); + + private enum HandshakeState { NOT_STARTED, NEED_READ, NEED_WRITE, COMPLETED } + + private final SocketChannel channel; + private final SSLEngine sslEngine; + private final Buffer wrapBuffer; + private final Buffer unwrapBuffer; + private int sessionPacketBufferSize; + private int sessionApplicationBufferSize; + private ByteBuffer handshakeDummyBuffer; + private HandshakeState handshakeState; + + public TlsCryptoSocket(SocketChannel channel, SSLEngine sslEngine) { + this.channel = channel; + this.sslEngine = sslEngine; + SSLSession nullSession = sslEngine.getSession(); + this.wrapBuffer = new Buffer(nullSession.getPacketBufferSize() * 2); + this.unwrapBuffer = new Buffer(nullSession.getPacketBufferSize() * 2); + // Note: Dummy buffer as unwrap requires a full size application buffer even though no application data is unwrapped + this.handshakeDummyBuffer = ByteBuffer.allocate(nullSession.getApplicationBufferSize()); + this.handshakeState = HandshakeState.NOT_STARTED; + } + + @Override + public SocketChannel channel() { + return channel; + } + + @Override + public HandshakeResult handshake() throws IOException { + HandshakeState newHandshakeState = processHandshakeState(this.handshakeState); + log.fine(() -> String.format("Handshake state '%s -> %s'", this.handshakeState, newHandshakeState)); + this.handshakeState = newHandshakeState; + return toHandshakeResult(newHandshakeState); + } + + private HandshakeState processHandshakeState(HandshakeState state) throws IOException { + switch (state) { + case NOT_STARTED: + sslEngine.beginHandshake(); + break; + case NEED_WRITE: + channelWrite(); + break; + case NEED_READ: + channelRead(); + break; + case COMPLETED: + return HandshakeState.COMPLETED; + default: + throw unhandledStateException(state); + } + + while (true) { + switch (sslEngine.getHandshakeStatus()) { + case NOT_HANDSHAKING: + if (wrapBuffer.bytes() > 0) return HandshakeState.NEED_WRITE; + sslEngine.setEnableSessionCreation(false); // disable renegotiation + handshakeDummyBuffer = null; + SSLSession session = sslEngine.getSession(); + sessionApplicationBufferSize = session.getApplicationBufferSize(); + sessionPacketBufferSize = session.getPacketBufferSize(); + return HandshakeState.COMPLETED; + case NEED_TASK: + sslEngine.getDelegatedTask().run(); + break; + case NEED_UNWRAP: + if (wrapBuffer.bytes() > 0) return HandshakeState.NEED_WRITE; + if (!handshakeUnwrap()) return HandshakeState.NEED_READ; + break; + case NEED_WRAP: + if (!handshakeWrap()) return HandshakeState.NEED_WRITE; + break; + default: + throw new IllegalStateException("Unexpected handshake status: " + sslEngine.getHandshakeStatus()); + } + } + } + + private static HandshakeResult toHandshakeResult(HandshakeState state) { + switch (state) { + case NEED_READ: + return HandshakeResult.NEED_READ; + case NEED_WRITE: + return HandshakeResult.NEED_WRITE; + case COMPLETED: + return HandshakeResult.DONE; + default: + throw unhandledStateException(state); + } + } + + @Override + public int getMinimumReadBufferSize() { + return sessionApplicationBufferSize; + } + + @Override + public int read(ByteBuffer dst) throws IOException { + verifyHandshakeCompleted(); + int bytesUnwrapped = drain(dst); + if (bytesUnwrapped > 0) return bytesUnwrapped; + + int bytesRead = channelRead(); + if (bytesRead == 0) return 0; + return drain(dst); + } + + @Override + public int drain(ByteBuffer dst) throws IOException { + verifyHandshakeCompleted(); + int totalBytesUnwrapped = 0; + int bytesUnwrapped; + do { + bytesUnwrapped = applicationDataUnwrap(dst); + totalBytesUnwrapped += bytesUnwrapped; + } while (bytesUnwrapped > 0); + return totalBytesUnwrapped; + } + + @Override + public int write(ByteBuffer src) throws IOException { + if (flush() == FlushResult.NEED_WRITE) return 0; + int totalBytesWrapped = 0; + int bytesWrapped; + do { + bytesWrapped = applicationDataWrap(src); + totalBytesWrapped += bytesWrapped; + } while (bytesWrapped > 0 && wrapBuffer.bytes() < sessionPacketBufferSize); + return totalBytesWrapped; + } + + @Override + public FlushResult flush() throws IOException { + channelWrite(); + return wrapBuffer.bytes() > 0 ? FlushResult.NEED_WRITE : FlushResult.DONE; + } + + private boolean handshakeWrap() throws IOException { + SSLEngineResult result = sslEngineWrap(NULL_BUFFER); + switch (result.getStatus()) { + case OK: + return true; + case BUFFER_OVERFLOW: + return false; + default: + throw unexpectedStatusException(result.getStatus()); + } + } + + private int applicationDataWrap(ByteBuffer src) throws IOException { + SSLEngineResult result = sslEngineWrap(src); + if (result.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING) throw new SSLException("Renegotiation detected"); + switch (result.getStatus()) { + case OK: + return result.bytesConsumed(); + case BUFFER_OVERFLOW: + return 0; + default: + throw unexpectedStatusException(result.getStatus()); + } + } + + private SSLEngineResult sslEngineWrap(ByteBuffer src) throws IOException { + SSLEngineResult result = sslEngine.wrap(src, wrapBuffer.getWritable(sessionPacketBufferSize)); + if (result.getStatus() == Status.CLOSED) throw new ClosedChannelException(); + return result; + } + + private boolean handshakeUnwrap() throws IOException { + SSLEngineResult result = sslEngineUnwrap(handshakeDummyBuffer); + switch (result.getStatus()) { + case OK: + if (result.bytesProduced() > 0) throw new SSLException("Got application data in handshake unwrap"); + return true; + case BUFFER_UNDERFLOW: + return false; + default: + throw unexpectedStatusException(result.getStatus()); + } + } + + private int applicationDataUnwrap(ByteBuffer dst) throws IOException { + SSLEngineResult result = sslEngineUnwrap(dst); + if (result.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING) throw new SSLException("Renegotiation detected"); + switch (result.getStatus()) { + case OK: + return result.bytesProduced(); + case BUFFER_OVERFLOW: + case BUFFER_UNDERFLOW: + return 0; + default: + throw unexpectedStatusException(result.getStatus()); + } + } + + private SSLEngineResult sslEngineUnwrap(ByteBuffer dst) throws IOException { + SSLEngineResult result = sslEngine.unwrap(unwrapBuffer.getReadable(), dst); + if (result.getStatus() == Status.CLOSED) throw new ClosedChannelException(); + return result; + } + + // returns number of bytes read + private int channelRead() throws IOException { + int read = channel.read(unwrapBuffer.getWritable(sessionPacketBufferSize)); + if (read == -1) throw new ClosedChannelException(); + return read; + } + + // returns number of bytes written + private int channelWrite() throws IOException { + return channel.write(wrapBuffer.getReadable()); + } + + private static IllegalStateException unhandledStateException(HandshakeState state) { + return new IllegalStateException("Unhandled state: " + state); + } + + private static IllegalStateException unexpectedStatusException(Status status) { + return new IllegalStateException("Unexpected status: " + status); + } + + private void verifyHandshakeCompleted() throws SSLException { + if (handshakeState != HandshakeState.COMPLETED) + throw new SSLException("Handshake not completed: handshakeState=" + handshakeState); + } + +} diff --git a/jrt/tests/com/yahoo/jrt/CryptoUtils.java b/jrt/tests/com/yahoo/jrt/CryptoUtils.java new file mode 100644 index 00000000000..c3128e09bd3 --- /dev/null +++ b/jrt/tests/com/yahoo/jrt/CryptoUtils.java @@ -0,0 +1,43 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jrt; + +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.SslContextBuilder; +import com.yahoo.security.X509CertificateBuilder; + +import javax.net.ssl.SSLContext; +import javax.security.auth.x500.X500Principal; +import java.security.KeyPair; +import java.security.KeyStore; +import java.security.cert.X509Certificate; +import java.time.Instant; + +import static com.yahoo.security.KeyAlgorithm.RSA; +import static com.yahoo.security.KeyStoreType.PKCS12; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_RSA; +import static com.yahoo.security.X509CertificateBuilder.generateRandomSerialNumber; +import static java.time.Instant.EPOCH; +import static java.time.temporal.ChronoUnit.DAYS; + +/** + * @author bjorncs + */ +class CryptoUtils { + static SSLContext createTestSslContext() { + KeyPair keyPair = KeyUtils.generateKeypair(RSA); + + X509Certificate certificate = X509CertificateBuilder + .fromKeypair(keyPair, new X500Principal("CN=dummy"), EPOCH, Instant.now().plus(1, DAYS), SHA256_WITH_RSA, generateRandomSerialNumber()) + .build(); + + KeyStore trustStore = KeyStoreBuilder.withType(PKCS12) + .withCertificateEntry("self-signed", certificate) + .build(); + + return new SslContextBuilder() + .withTrustStore(trustStore) + .withKeyStore(keyPair.getPrivate(), certificate) + .build(); + } +} diff --git a/jrt/tests/com/yahoo/jrt/EchoTest.java b/jrt/tests/com/yahoo/jrt/EchoTest.java index 0523241354a..a91ac117f41 100644 --- a/jrt/tests/com/yahoo/jrt/EchoTest.java +++ b/jrt/tests/com/yahoo/jrt/EchoTest.java @@ -5,11 +5,11 @@ package com.yahoo.jrt; import org.junit.After; import org.junit.Before; import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameter; import org.junit.runners.Parameterized.Parameters; -import org.junit.runners.Parameterized; - +import static com.yahoo.jrt.CryptoUtils.createTestSslContext; import static org.junit.Assert.assertTrue; @RunWith(Parameterized.class) @@ -22,8 +22,8 @@ public class EchoTest { Values refValues; @Parameter public CryptoEngine crypto; - @Parameters public static Object[] engines() { - return new Object[] { CryptoEngine.createDefault(), new XorCryptoEngine() }; + @Parameters(name = "{0}") public static Object[] engines() { + return new Object[] { CryptoEngine.createDefault(), new XorCryptoEngine(), new TlsCryptoEngine(createTestSslContext()) }; } @Before diff --git a/jrt/tests/com/yahoo/jrt/SessionTest.java b/jrt/tests/com/yahoo/jrt/SessionTest.java index 2f1a64538de..63d14601b6e 100644 --- a/jrt/tests/com/yahoo/jrt/SessionTest.java +++ b/jrt/tests/com/yahoo/jrt/SessionTest.java @@ -5,10 +5,11 @@ package com.yahoo.jrt; import org.junit.After; import org.junit.Before; import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameter; import org.junit.runners.Parameterized.Parameters; -import org.junit.runners.Parameterized; +import static com.yahoo.jrt.CryptoUtils.createTestSslContext; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -17,8 +18,8 @@ import static org.junit.Assert.assertTrue; public class SessionTest implements SessionHandler { @Parameter public CryptoEngine crypto; - @Parameters public static Object[] engines() { - return new Object[] { CryptoEngine.createDefault(), new XorCryptoEngine() }; + @Parameters(name = "{0}") public static Object[] engines() { + return new Object[] { CryptoEngine.createDefault(), new XorCryptoEngine(), new TlsCryptoEngine(createTestSslContext()) }; } private static class Session { diff --git a/jrt_test/src/jrt-test/simpleserver/simpleserver.cpp b/jrt_test/src/jrt-test/simpleserver/simpleserver.cpp index ed7ff0e40bc..89d8cd881a8 100644 --- a/jrt_test/src/jrt-test/simpleserver/simpleserver.cpp +++ b/jrt_test/src/jrt-test/simpleserver/simpleserver.cpp @@ -14,19 +14,19 @@ public: { FRT_ReflectionBuilder rb(s); //--------------------------------------------------------------------- - rb.DefineMethod("inc", "i", "i", true, + rb.DefineMethod("inc", "i", "i", FRT_METHOD(Server::rpc_inc), this); rb.MethodDesc("Increase an integer value"); rb.ParamDesc("value", "initial value"); rb.ReturnDesc("result", "value + 1"); //--------------------------------------------------------------------- - rb.DefineMethod("blob", "x", "x", true, + rb.DefineMethod("blob", "x", "x", FRT_METHOD(Server::rpc_blob), this); rb.MethodDesc("Send a copy of a blob back to the client"); rb.ParamDesc("blob", "the original blob"); rb.ReturnDesc("blob", "a copy of the original blob"); //--------------------------------------------------------------------- - rb.DefineMethod("test", "iib", "i", true, + rb.DefineMethod("test", "iib", "i", FRT_METHOD(Server::rpc_test), this); rb.MethodDesc("Magic test method"); rb.ParamDesc("value", "the value"); @@ -76,15 +76,10 @@ int App::Main() { if (_argc < 2) { - printf("usage: %s <listenspec> [ddw]\n", _argv[0]); - printf(" ddw = disable direct write\n"); + printf("usage: %s <listenspec>\n", _argv[0]); return 1; } FRT_Supervisor orb; - if (_argc >= 3 && strcmp(_argv[2], "ddw") == 0) { - printf("(direct write disabled)\n"); - orb.GetTransport()->SetDirectWrite(false); - } Server server(&orb); orb.Listen(_argv[1]); orb.Main(); diff --git a/jrt_test/src/tests/mockup-invoke/mockup-server.cpp b/jrt_test/src/tests/mockup-invoke/mockup-server.cpp index 32c9bcc6c21..8456bee1e41 100644 --- a/jrt_test/src/tests/mockup-invoke/mockup-server.cpp +++ b/jrt_test/src/tests/mockup-invoke/mockup-server.cpp @@ -14,7 +14,7 @@ public: { FRT_ReflectionBuilder rb(s); //------------------------------------------------------------------- - rb.DefineMethod("concat", "ss", "s", true, + rb.DefineMethod("concat", "ss", "s", FRT_METHOD(MockupServer::RPC_concat), this); rb.MethodDesc("Concatenate two strings"); rb.ParamDesc("string1", "a string"); diff --git a/juniper/build/buildspec.xml b/juniper/build/buildspec.xml deleted file mode 100644 index 335140f192f..00000000000 --- a/juniper/build/buildspec.xml +++ /dev/null @@ -1,60 +0,0 @@ -<!-- Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> -<BuildSpecification> - <Dependencies os="all" arch="all"> - <dep package="common/fastos" version="1.5.99.2"> - <!-- Use stl-port when compiling on platforms with gcc and on windows --> - <if os="linux,freebsd"> - <addfeature name="stl-port" /> - </if> - </dep> - <dep package="common/fastlib" version="1.6.4" /> - <dep package="3rdparty/libiconv" version="1.8" featureset="iconvconst" /> - </Dependencies> - - <PreBuild os="all" arch="all" nocopy="yes"> - <if os="XXaix"> - <!-- Iconv library has been renamed to fsiconv on aix to avoid --> - <!-- mix-ups with other installations. We must alter some --> - <!-- files to make the linking work. --> - <replace token="LIBDIR_ICONV" value="LIBDIR_FSICONV"> - <file path="Makefile" /> - <file path="configure.cfg" /> - </replace> - <replace token=" iconv" value=" fsiconv"> - <file path="src/test/fastos.project" /> - </replace> - </if> - <configure path=""> - <parameter value="--fastos-dir ${fbuild_install_dir}/fastos" /> - <parameter value="--fastlib-dir ${fbuild_install_dir}/fastlib" /> - <parameter value="--iconv-dir ${fbuild_install_dir}/iconv" /> - <parameter value="--install-dir ${fbuild_install_dir}/juniper" /> - <parameter value="--version ${fbuild_package_version}" /> - </configure> - <make path="" target="makefiles" /> - <make path="" target="depend" /> - </PreBuild> - - <Build os="unix" arch="all"> - <make path="" target="all" /> - </Build> - - <Build os="win32" arch="all"> - <make path="" target="all" /> - </Build> - - <PostBuild os="all" arch="all"> - </PostBuild> - - <Test os="all" arch="all"> - <make path="src/test" parameters="test" /> - </Test> - - <Install os="all" arch="all"> - <make path="" target="install" /> - </Install> - - <Dist os="all" arch="all"> - </Dist> - -</BuildSpecification> diff --git a/linguistics/src/main/java/com/yahoo/language/detect/Detection.java b/linguistics/src/main/java/com/yahoo/language/detect/Detection.java index c08bdc14cfb..4b816335154 100644 --- a/linguistics/src/main/java/com/yahoo/language/detect/Detection.java +++ b/linguistics/src/main/java/com/yahoo/language/detect/Detection.java @@ -7,7 +7,7 @@ import java.nio.charset.Charset; import java.nio.charset.UnsupportedCharsetException; /** - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> + * @author Einar M R Rosenvinge */ public class Detection { diff --git a/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java b/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java index 12de309a2d3..7451a7f2c9c 100644 --- a/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java +++ b/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java @@ -4,8 +4,10 @@ import com.yahoo.language.process.Tokenizer; import com.yahoo.language.simple.SimpleLinguistics; public class OpenNlpLinguistics extends SimpleLinguistics { + @Override public Tokenizer getTokenizer() { return new OpenNlpTokenizer(getNormalizer(), getTransformer()); } + } diff --git a/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java b/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java index 2b31f95675b..0503ac61df1 100644 --- a/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java +++ b/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java @@ -37,36 +37,49 @@ import java.util.Locale; * @author bjorncs */ public class SimpleDetector implements Detector { - static private TextObjectFactory textObjectFactory; - static private LanguageDetector languageDetector; - - static { - // origin: https://github.com/optimaize/language-detector - //load all languages: - List<LanguageProfile> languageProfiles; - try { - languageProfiles = new LanguageProfileReader().readAllBuiltIn(); - } catch (IOException e) { - throw new RuntimeException(e); - } - //build language detector: - languageDetector = LanguageDetectorBuilder.create(NgramExtractors.standard()) - .withProfiles(languageProfiles) - .build(); + static private Object initGuard = new Object(); + static private TextObjectFactory textObjectFactory = null; + static private LanguageDetector languageDetector = null; + + static private void initOptimaize (boolean useOptimaize) { + if (!useOptimaize) return; + synchronized (initGuard) { + if ((textObjectFactory != null) && (languageDetector != null)) return; + + // origin: https://github.com/optimaize/language-detector + //load all languages: + List<LanguageProfile> languageProfiles; + try { + languageProfiles = new LanguageProfileReader().readAllBuiltIn(); + } catch (IOException e) { + throw new RuntimeException(e); + } + + //build language detector: + languageDetector = LanguageDetectorBuilder.create(NgramExtractors.standard()) + .withProfiles(languageProfiles) + .build(); - //create a text object factory - textObjectFactory = CommonTextObjectFactories.forDetectingOnLargeText(); + //create a text object factory + textObjectFactory = CommonTextObjectFactories.forDetectingOnLargeText(); + } } private final boolean enableOptimaize; + private SimpleDetector(boolean enableOptimaize) { + initOptimaize(enableOptimaize); + this.enableOptimaize = enableOptimaize; + + } + public SimpleDetector() { - this.enableOptimaize = true; + this(true); } public SimpleDetector(SimpleLinguisticsConfig.Detector detector) { - this.enableOptimaize = detector.enableOptimaize(); + this(detector.enableOptimaize()); } @Override diff --git a/logd/CMakeLists.txt b/logd/CMakeLists.txt index 3eeeb7adb66..6a8296564a3 100644 --- a/logd/CMakeLists.txt +++ b/logd/CMakeLists.txt @@ -18,3 +18,5 @@ vespa_define_module( src/tests/info src/tests/rotate ) + +vespa_install_script(src/apps/retention/retention-enforcer.sh vespa-retention-enforcer sbin) diff --git a/logd/src/apps/retention/retention-enforcer.sh b/logd/src/apps/retention/retention-enforcer.sh new file mode 100755 index 00000000000..7ab1b27d71a --- /dev/null +++ b/logd/src/apps/retention/retention-enforcer.sh @@ -0,0 +1,137 @@ +#!/bin/sh + +# daemon that collects old log files. +# global settings: + +DBGF=logs/vespa/debug.retention-enforcer +DBDIR=var/db/vespa/logfiledb +PIDF=$DBDIR/retention-enforcer.pid +RETAIN_DAYS=31 + +# this depends on components adding their log files +# to a "database" in DBDIR named "logfiles.TTTTT" where +# TTTTT is a timestamp in format (seconds/100000). +# The "database" holds lines with format "timestamp /path/to/logfile" +# where "timestamp" is just seconds since epoch. + +prereq_dir() { + if [ -d $1 ] && [ -w $1 ]; then + : + else + echo "$0: missing directory '$1' in '`pwd`'" >&2 + exit 1 + fi +} + +check_prereqs() { + prereq_dir var/db/vespa + prereq_dir logs/vespa +} + +ensure_dir () { + if [ -d $1 ] && [ -w $1 ]; then + return 0 + fi + echo "Creating directory '$1' in '`pwd`'" + mkdir -p $1 || exit 1 +} + +prepare_stuff() { + check_prereqs + exec > $DBGF.$$.log 2>&1 + ensure_dir $DBDIR +} + +bad_timestamp() { + now=$(date +%s) + if [ "$1" ] && [ "$1" -ge 1514764800 ] && [ "$1" -le $now ]; then + # sane timestamp: + return 1 + fi + # bad timestamp: + return 0 +} + +mark_pid() { + echo $$ > $PIDF.$$.tmp + mv $PIDF.$$.tmp $PIDF || exit 1 +} + +check_pidfile() { + read pid < $PIDF + [ "$pid" = $$ ] && return 0 + if [ "$pid" ] && [ $pid -gt $$ ]; then + sleep 30 + read pid_again < $PIDF + if [ "$pid_again" != "$pid" ]; then return 1; fi + ps -p $pid >/dev/null 2>&1 || return 1 + proc=$(ps -p $pid 2>&1) + case $proc in *retention*) ;; *) return 1;; esac + echo "$0 [$$]: Yielding my place to pid '$pid'" + exit 1 + fi +} + +get_mod_time() { + perl -e 'print (((stat("'"$1"'"))[9]) . "\n")' +} + +maybe_collect() { + timestamp=$1 + logfilename=$2 + + if bad_timestamp "$1"; then + echo "WARNING: bad timestamp '$timestamp' for logfilename '$logfilename'" + return + fi + + add=$((86400 * $RETAIN_DAYS)) + lim1=$(($timestamp + $add)) + mod_time=$(get_mod_time "$logfilename") + lim2=$(($mod_time + $add)) + + if [ $lim1 -lt $now ] && [ $lim2 -lt $now ]; then + echo "Collect logfile '$logfilename' timestamped $timestamp modified $mod_time" + rm -f "$logfilename" + fi +} + +process_file() { + dbfile="$1" + now=$(date +%s) + found=0 + while read timestamp logfilename; do + for fn in $logfilename $logfilename.*z*; do + if [ -f "$fn" ]; then + found=1 + maybe_collect "$timestamp" "$fn" + fi + done + done < $dbfile + if [ $found = 0 ]; then + ts=${dbfile##*.}99999 + maybe_collect "$ts" "$dbfile" + fi +} + +process_all() { + for dbf in $DBDIR/logfiles.* ; do + [ -f "$dbf" ] || continue + process_file "$dbf" + done +} + +mainloop() { + while true; do + mark_pid + process_all + sleep 3600 + check_pidfile + done +} + +# MAIN: + +prepare_stuff +mainloop +exit 0 diff --git a/logserver/pom.xml b/logserver/pom.xml index 55c5c443cfb..71eda0fb15b 100644 --- a/logserver/pom.xml +++ b/logserver/pom.xml @@ -20,11 +20,6 @@ <scope>test</scope> </dependency> <dependency> - <groupId>org.hamcrest</groupId> - <artifactId>hamcrest-core</artifactId> - <scope>test</scope> - </dependency> - <dependency> <groupId>com.yahoo.vespa</groupId> <artifactId>vespajlib</artifactId> <version>${project.version}</version> @@ -34,10 +29,6 @@ <artifactId>vespalog</artifactId> <version>${project.version}</version> </dependency> - <dependency> - <groupId>com.fasterxml.jackson.core</groupId> - <artifactId>jackson-databind</artifactId> - </dependency> </dependencies> <build> <plugins> diff --git a/messagebus/src/main/java/com/yahoo/messagebus/ErrorCode.java b/messagebus/src/main/java/com/yahoo/messagebus/ErrorCode.java index 8794bd507a2..e54279e0541 100644 --- a/messagebus/src/main/java/com/yahoo/messagebus/ErrorCode.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/ErrorCode.java @@ -50,11 +50,6 @@ public final class ErrorCode { /** No services found for the message route. */ public static final int NO_SERVICES_FOR_ROUTE = FATAL_ERROR + 3; - /** The selected service was out of service. - */ - @Deprecated // Unused and will be removed - public static final int SERVICE_OOS = FATAL_ERROR + 4; - /** An error occured while encoding the message. */ public static final int ENCODE_ERROR = FATAL_ERROR + 5; @@ -118,7 +113,6 @@ public final class ErrorCode { case SEND_QUEUE_CLOSED : return "SEND_QUEUE_CLOSED"; case SEND_QUEUE_FULL : return "SEND_QUEUE_FULL"; case SEQUENCE_ERROR : return "SEQUENCE_ERROR"; - case SERVICE_OOS : return "SERVICE_OOS"; case SESSION_BUSY : return "SESSION_BUSY"; case TIMEOUT : return "TIMEOUT"; case TRANSIENT_ERROR : return "TRANSIENT_ERROR"; diff --git a/messagebus/src/main/java/com/yahoo/messagebus/Message.java b/messagebus/src/main/java/com/yahoo/messagebus/Message.java index 22496487f61..43f5c8d2dfd 100644 --- a/messagebus/src/main/java/com/yahoo/messagebus/Message.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/Message.java @@ -5,9 +5,9 @@ import com.yahoo.concurrent.SystemTimer; import com.yahoo.messagebus.routing.Route; /** - * <p>A message is a child of Routable, it is not a reply, and it has a sequencing identifier. Furthermore, a message + * A message is a child of Routable, it is not a reply, and it has a sequencing identifier. Furthermore, a message * contains a retry counter that holds what retry the message is currently on. See the method comment {@link #getRetry} - * for more information.</p> + * for more information. * * @author Simon Thoresen Hult */ diff --git a/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java b/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java index 63514eca6dd..e21aeef1ee2 100755 --- a/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java @@ -20,15 +20,15 @@ public class Hop { private String cache = null; /** - * <p>Constructs an empty hop. You will need to add directives to the - * selector to make this usable.</p> + * Constructs an empty hop. You will need to add directives to the + * selector to make this usable. */ public Hop() { // empty } /** - * <p>Implements the copy constructor.</p> + * Implements the copy constructor. * * @param hop The hop to copy. */ @@ -38,8 +38,8 @@ public class Hop { } /** - * <p>Constructs a fully populated hop. This is package private and used by - * the {@link HopBlueprint#create()} method.</p> + * Constructs a fully populated hop. This is package private and used by + * the {@link HopBlueprint#create()} method. * * @param selector The selector to copy. * @param ignoreResult Whether or not to ignore the result of this hop. @@ -50,8 +50,8 @@ public class Hop { } /** - * <p>Parses the given string as a single hop. The {@link #toString()} - * method is compatible with this parser.</p> + * Parses the given string as a single hop. The {@link #toString()} + * method is compatible with this parser. * * @param str The string to parse. * @return A hop that corresponds to the string. @@ -65,8 +65,7 @@ public class Hop { } /** - * <p>Returns whether or not there are any directives contained in this - * hop.</p> + * Returns whether or not there are any directives contained in this hop. * * @return True if there is at least one directive. */ @@ -75,7 +74,7 @@ public class Hop { } /** - * <p>Returns the number of directives contained in this hop.</p> + * Returns the number of directives contained in this hop. * * @return The number of directives. */ @@ -84,7 +83,7 @@ public class Hop { } /** - * <p>Returns the directive at the given index.</p> + * Returns the directive at the given index. * * @param i The index of the directive to return. * @return The item. @@ -94,7 +93,7 @@ public class Hop { } /** - * <p>Adds a new directive to this hop.</p> + * Adds a new directive to this hop. * * @param directive The directive to add. * @return This, to allow chaining. @@ -106,7 +105,7 @@ public class Hop { } /** - * <p>Sets the directive at a given index.</p> + * Sets the directive at a given index. * * @param i The index at which to set the directive. * @param directive The directive to set. @@ -283,9 +282,10 @@ public class Hop { @Override public int hashCode() { - int result = selector != null ? selector.hashCode() : 0; + int result = selector.hashCode(); result = 31 * result + (ignoreResult ? 1 : 0); result = 31 * result + (cache != null ? cache.hashCode() : 0); return result; } + } diff --git a/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java b/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java index 809b2da69c4..838b11e7a02 100755 --- a/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java @@ -14,13 +14,14 @@ public interface HopDirective { * @param dir The directive to compare this to. * @return True if this matches the argument. */ - public boolean matches(HopDirective dir); + boolean matches(HopDirective dir); /** * Returns a string representation of this that can be debugged but not parsed. * * @return The debug string. */ - public String toDebugString(); + String toDebugString(); + } diff --git a/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java b/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java index a07c6e16100..9190b680ebf 100755 --- a/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java @@ -20,7 +20,7 @@ import java.util.List; */ public class Route { - private final List<Hop> hops = new ArrayList<Hop>(); + private final List<Hop> hops = new ArrayList<>(); private String cache = null; /** diff --git a/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp b/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp index 108b94070bf..b72416f51d2 100644 --- a/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp +++ b/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp @@ -128,7 +128,6 @@ RPCNetwork::RPCNetwork(const RPCNetworkParams ¶ms) : _allowDispatchForEncode(params.getDispatchOnEncode()), _allowDispatchForDecode(params.getDispatchOnDecode()) { - _transport->SetDirectWrite(false); _transport->SetMaxInputBufferSize(params.getMaxInputBufferSize()); _transport->SetMaxOutputBufferSize(params.getMaxOutputBufferSize()); } @@ -188,7 +187,7 @@ RPCNetwork::attach(INetworkOwner &owner) _sendAdapters[vespalib::Version(6, 149)] = _sendV2.get(); FRT_ReflectionBuilder builder(_orb.get()); - builder.DefineMethod("mbus.getVersion", "", "s", true, FRT_METHOD(RPCNetwork::invoke), this); + builder.DefineMethod("mbus.getVersion", "", "s", FRT_METHOD(RPCNetwork::invoke), this); builder.MethodDesc("Retrieves the message bus version."); builder.ReturnDesc("version", "The message bus version."); } diff --git a/messagebus/src/vespa/messagebus/network/rpcsendv1.cpp b/messagebus/src/vespa/messagebus/network/rpcsendv1.cpp index 6b89a278b88..376267b555c 100644 --- a/messagebus/src/vespa/messagebus/network/rpcsendv1.cpp +++ b/messagebus/src/vespa/messagebus/network/rpcsendv1.cpp @@ -35,7 +35,7 @@ RPCSendV1::getReturnSpec() const { void RPCSendV1::build(FRT_ReflectionBuilder & builder) { - builder.DefineMethod(METHOD_NAME, METHOD_PARAMS, METHOD_RETURN, true, FRT_METHOD(RPCSendV1::invoke), this); + builder.DefineMethod(METHOD_NAME, METHOD_PARAMS, METHOD_RETURN, FRT_METHOD(RPCSendV1::invoke), this); builder.MethodDesc("Send a message bus request and get a reply back."); builder.ParamDesc("version", "The version of the message."); builder.ParamDesc("route", "Names of additional hops to visit."); diff --git a/messagebus/src/vespa/messagebus/network/rpcsendv2.cpp b/messagebus/src/vespa/messagebus/network/rpcsendv2.cpp index 4c04549aee1..91a41a6a800 100644 --- a/messagebus/src/vespa/messagebus/network/rpcsendv2.cpp +++ b/messagebus/src/vespa/messagebus/network/rpcsendv2.cpp @@ -59,7 +59,7 @@ bool RPCSendV2::isCompatible(stringref method, stringref request, stringref resp void RPCSendV2::build(FRT_ReflectionBuilder & builder) { - builder.DefineMethod(METHOD_NAME, METHOD_PARAMS, METHOD_RETURN, true, FRT_METHOD(RPCSendV2::invoke), this); + builder.DefineMethod(METHOD_NAME, METHOD_PARAMS, METHOD_RETURN, FRT_METHOD(RPCSendV2::invoke), this); builder.MethodDesc("Send a message bus slime request and get a reply back."); builder.ParamDesc("header_encoding", "0=raw, 6=lz4"); builder.ParamDesc("header_decoded_size", "Uncompressed header blob size"); diff --git a/model-evaluation/OWNERS b/model-evaluation/OWNERS new file mode 100644 index 00000000000..2bd865cff34 --- /dev/null +++ b/model-evaluation/OWNERS @@ -0,0 +1,2 @@ +bratseth +lesters diff --git a/model-evaluation/README b/model-evaluation/README new file mode 100644 index 00000000000..0bf143a2804 --- /dev/null +++ b/model-evaluation/README @@ -0,0 +1,6 @@ +Provides +- an injectable component (ai.vespa.models.evaluation.ModelsEvaluator) which allows direct, stateless evaluation of + any machine learned models added to the models/ directory in any container. +- a handler (turned on with the <models-evaluation> tag in <container>) which provides the models-evaluation REST + API which provides stateless (single data point) model evaluation over HTTP(S). + diff --git a/model-evaluation/pom.xml b/model-evaluation/pom.xml index 6fdc25f3786..328d475c501 100644 --- a/model-evaluation/pom.xml +++ b/model-evaluation/pom.xml @@ -1,5 +1,5 @@ <?xml version="1.0"?> -<!-- Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 @@ -40,6 +40,11 @@ </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> + <artifactId>searchcore</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> <artifactId>config</artifactId> <version>${project.version}</version> <scope>provided</scope> @@ -75,22 +80,6 @@ <artifactId>bundle-plugin</artifactId> <extensions>true</extensions> </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-jar-plugin</artifactId> - <configuration> - <archive> - <manifestEntries> - <Bundle-SymbolicName>${project.artifactId}</Bundle-SymbolicName> - <Vespa-Version>${project.version}</Vespa-Version> - </manifestEntries> - </archive> - </configuration> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-compiler-plugin</artifactId> - </plugin> </plugins> </build> </project> diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java new file mode 100644 index 00000000000..e664693ab38 --- /dev/null +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java @@ -0,0 +1,27 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import com.yahoo.tensor.Tensor; + +/** + * A named constant loaded from a file. + * + * This is immutable. + * + * @author bratseth + */ +class Constant { + + private final String name; + private final Tensor value; + + Constant(String name, Tensor value) { + this.name = name; + this.value = value; + } + + public String name() { return name; } + + public Tensor value() { return value; } + +} diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 520986ffb77..e08b9f77d15 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -56,4 +56,6 @@ public class FunctionEvaluator { return function.getBody().evaluate(context).asTensor(); } + LazyArrayContext context() { return context; } + } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 2dcfd204077..beaa36b898f 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -16,6 +17,7 @@ import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Set; @@ -37,8 +39,11 @@ final class LazyArrayContext extends Context implements ContextIndex { * * @param expression the expression to create a context for */ - LazyArrayContext(RankingExpression expression, Map<FunctionReference, ExpressionFunction> functions, Model model) { - this.indexedBindings = new IndexedBindings(expression, functions, this, model); + LazyArrayContext(RankingExpression expression, + Map<FunctionReference, ExpressionFunction> functions, + List<Constant> constants, + Model model) { + this.indexedBindings = new IndexedBindings(expression, functions, constants, this, model); } /** @@ -139,8 +144,10 @@ final class LazyArrayContext extends Context implements ContextIndex { */ IndexedBindings(RankingExpression expression, Map<FunctionReference, ExpressionFunction> functions, + List<Constant> constants, LazyArrayContext owner, Model model) { + // 1. Determine and prepare bind targets Set<String> bindTargets = new LinkedHashSet<>(); extractBindTargets(expression.getRoot(), functions, bindTargets); @@ -150,9 +157,18 @@ final class LazyArrayContext extends Context implements ContextIndex { int i = 0; ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>(); for (String variable : bindTargets) - nameToIndexBuilder.put(variable,i++); + nameToIndexBuilder.put(variable, i++); nameToIndex = nameToIndexBuilder.build(); + + // 2. Bind the bind targets + for (Constant constant : constants) { + String constantReference = "constant(" + constant.name() + ")"; + Integer index = nameToIndex.get(constantReference); + if (index != null) + values[index] = new TensorValue(constant.value()); + } + for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { Integer index = nameToIndex.get(function.getKey().serialForm()); if (index != null) // Referenced in this, so bind it @@ -170,7 +186,7 @@ final class LazyArrayContext extends Context implements ContextIndex { extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets); } else if (isConstant(node)) { - // Ignore + bindTargets.add(node.toString()); } else if (node instanceof ReferenceNode) { bindTargets.add(node.toString()); @@ -193,7 +209,7 @@ final class LazyArrayContext extends Context implements ContextIndex { if ( ! (node instanceof ReferenceNode)) return false; ReferenceNode reference = (ReferenceNode)node; - return reference.getName().equals("value") && reference.getArguments().size() == 1; + return reference.getName().equals("constant") && reference.getArguments().size() == 1; } Value get(int index) { return values[index]; } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 95eb923786d..3fb43d73187 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -36,11 +36,15 @@ public class Model { private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer(); + /** Programmatically create a model containing functions without constant of function references only */ public Model(String name, Collection<ExpressionFunction> functions) { - this(name, functions, Collections.emptyMap()); + this(name, functions, Collections.emptyMap(), Collections.emptyList()); } - Model(String name, Collection<ExpressionFunction> functions, Map<FunctionReference, ExpressionFunction> referencedFunctions) { + Model(String name, + Collection<ExpressionFunction> functions, + Map<FunctionReference, ExpressionFunction> referencedFunctions, + List<Constant> constants) { // TODO: Optimize functions this.name = name; this.functions = ImmutableList.copyOf(functions); @@ -48,7 +52,8 @@ public class Model { ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); for (ExpressionFunction function : functions) { try { - contextBuilder.put(function.getName(), new LazyArrayContext(function.getBody(), referencedFunctions, this)); + contextBuilder.put(function.getName(), + new LazyArrayContext(function.getBody(), referencedFunctions, constants, this)); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index dacf20b7ef2..a0b859bf930 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -3,8 +3,11 @@ package ai.vespa.models.evaluation; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableMap; +import com.google.inject.Inject; import com.yahoo.component.AbstractComponent; +import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import java.util.Map; import java.util.stream.Collectors; @@ -21,8 +24,15 @@ public class ModelsEvaluator extends AbstractComponent { private final ImmutableMap<String, Model> models; - public ModelsEvaluator(RankProfilesConfig config) { - models = ImmutableMap.copyOf(new RankProfilesConfigImporter().importFrom(config)); + @Inject + public ModelsEvaluator(RankProfilesConfig config, + RankingConstantsConfig constantsConfig, + FileAcquirer fileAcquirer) { + this(new RankProfilesConfigImporter(fileAcquirer).importFrom(config, constantsConfig)); + } + + public ModelsEvaluator(Map<String, Model> models) { + this.models = ImmutableMap.copyOf(models); } /** Returns the models of this as an immutable map */ diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index bfd6342218a..d2fca309a19 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -1,33 +1,54 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import com.yahoo.config.FileReference; +import com.yahoo.filedistribution.fileacquirer.FileAcquirer; +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; /** - * Converts RankProfilesConfig instances to RankingExpressions for evaluation + * Converts RankProfilesConfig instances to RankingExpressions for evaluation. + * This class can be used by a single thread only. * * @author bratseth */ -class RankProfilesConfigImporter { +public class RankProfilesConfigImporter { + + private final FileAcquirer fileAcquirer; + + public RankProfilesConfigImporter(FileAcquirer fileAcquirer) { + this.fileAcquirer = fileAcquirer; + } /** * Returns a map of the models contained in this config, indexed on name. * The map is modifiable and owned by the caller. */ - Map<String, Model> importFrom(RankProfilesConfig config) { + public Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) { try { Map<String, Model> models = new HashMap<>(); for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { - Model model = importProfile(profile); + Model model = importProfile(profile, constantsConfig); models.put(model.name(), model); } return models; @@ -37,11 +58,16 @@ class RankProfilesConfigImporter { } } - private Model importProfile(RankProfilesConfig.Rankprofile profile) throws ParseException { + private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) + throws ParseException { List<ExpressionFunction> functions = new ArrayList<>(); Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>(); + SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo(); ExpressionFunction firstPhase = null; ExpressionFunction secondPhase = null; + + List<Constant> constants = readLargeConstants(constantsConfig); + for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) { Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name()); if ( reference.isPresent()) { @@ -52,7 +78,8 @@ class RankProfilesConfigImporter { functions.add(new ExpressionFunction(reference.get().functionName(), arguments, expression)); // // Make all functions, bound or not available under the name they are referenced by in expressions - referencedFunctions.put(reference.get(), new ExpressionFunction(reference.get().serialForm(), arguments, expression)); + referencedFunctions.put(reference.get(), + new ExpressionFunction(reference.get().serialForm(), arguments, expression)); } else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to macros firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(), @@ -62,14 +89,19 @@ class RankProfilesConfigImporter { secondPhase = new ExpressionFunction("secondphase", new ArrayList<>(), new RankingExpression("second-phase", property.value())); } + else { + smallConstantsInfo.addIfSmallConstantInfo(property.name(), property.value()); + } } if (functionByName("firstphase", functions) == null && firstPhase != null) // may be already included, depending on body functions.add(firstPhase); if (functionByName("secondphase", functions) == null && secondPhase != null) // may be already included, depending on body functions.add(secondPhase); + constants.addAll(smallConstantsInfo.asConstants()); + try { - return new Model(profile.name(), functions, referencedFunctions); + return new Model(profile.name(), functions, referencedFunctions, constants); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e); @@ -83,4 +115,73 @@ class RankProfilesConfigImporter { return null; } + private List<Constant> readLargeConstants(RankingConstantsConfig constantsConfig) { + List<Constant> constants = new ArrayList<>(); + + for (RankingConstantsConfig.Constant constantConfig : constantsConfig.constant()) { + constants.add(new Constant(constantConfig.name(), + readTensorFromFile(constantConfig.name(), + TensorType.fromSpec(constantConfig.type()), + constantConfig.fileref()))); + } + return constants; + } + + protected Tensor readTensorFromFile(String name, TensorType type, FileReference fileReference) { + try { + File file = fileAcquirer.waitFor(fileReference, 7, TimeUnit.DAYS); + if (file.getName().endsWith(".tbf")) + return TypedBinaryFormat.decode(Optional.of(type), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(file))); + else + throw new IllegalArgumentException("Constant files on other formats than .tbf are not supported, got " + + file + " for constant " + name); + // TODO: Support json and json.lz4 + } + catch (InterruptedException e) { + throw new IllegalStateException("Gave up waiting for constant " + name); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** Collected information about small constants */ + private static class SmallConstantsInfo { + + private static final Pattern valuePattern = Pattern.compile("constant\\(([a-zA-Z0-9_.]+)\\)\\.value"); + private static final Pattern typePattern = Pattern.compile("constant\\(([a-zA-Z0-9_.]+)\\)\\.type"); + + private Map<String, TensorType> types = new HashMap<>(); + private Map<String, String> values = new HashMap<>(); + + void addIfSmallConstantInfo(String key, String value) { + tryValue(key, value); + tryType(key, value); + } + + private void tryValue(String key, String value) { + Matcher matcher = valuePattern.matcher(key); + if (matcher.matches()) + values.put(matcher.group(1), value); + } + + private void tryType(String key, String value) { + Matcher matcher = typePattern.matcher(key); + if (matcher.matches()) + types.put(matcher.group(1), TensorType.fromSpec(value)); + } + + List<Constant> asConstants() { + List<Constant> constants = new ArrayList<>(); + for (Map.Entry<String, String> entry : values.entrySet()) { + TensorType type = types.get(entry.getKey()); + if (type == null) throw new IllegalStateException("Missing type of '" + entry.getKey() + "'"); // Won't happen + constants.add(new Constant(entry.getKey(), Tensor.from(type, entry.getValue()))); + } + return constants; + } + + } + } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java new file mode 100644 index 00000000000..6e55c0c9a53 --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -0,0 +1,74 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import org.junit.Test; + +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; + +/** + * Tests instantiating models from rank-profiles configs. + * + * @author bratseth + */ +public class MlModelsImportingTest { + + private static final double delta = 0.00000000001; + + @Test + public void testImportingModels() { + ModelTester tester = new ModelTester("src/test/resources/config/models/"); + + assertEquals(4, tester.models().size()); + + // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that + { + Model xgboost = tester.models().get("xgboost_2_2"); + tester.assertFunction("xgboost_2_2", + "(optimized sum of condition trees of size 192 bytes)", + xgboost); + FunctionEvaluator evaluator = xgboost.evaluatorOf(); + assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-8.17695, evaluator.evaluate().sum().asDouble(), delta); + } + + { + + Model onnxMnistSoftmax = tester.models().get("mnist_softmax"); + tester.assertFunction("default.add", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", + onnxMnistSoftmax); + assertEquals("tensor(d1[10],d2[784])", + onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); + FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available + assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta); + } + + { + Model tfMnistSoftmax = tester.models().get("mnist_softmax_saved"); + tester.assertFunction("serving_default.y", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", + tfMnistSoftmax); + FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available + assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta); + } + + { + Model tfMnist = tester.models().get("mnist_saved"); + tester.assertFunction("serving_default.y", + "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", + tfMnist); + // Macro: + tester.assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add", + "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", + tfMnist); + FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument + assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-0.714629131972222, evaluator.evaluate().sum().asDouble(), delta); + } + } + +} diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java new file mode 100644 index 00000000000..0aceaccc3e0 --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java @@ -0,0 +1,94 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import com.yahoo.config.FileReference; +import com.yahoo.config.subscription.ConfigGetter; +import com.yahoo.config.subscription.FileSource; +import com.yahoo.filedistribution.fileacquirer.FileAcquirer; +import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; + +import java.io.IOException; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Logger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * Helper for testing model import and evaluation + * + * @author bratseth + */ +public class ModelTester { + + private final Map<String, Model> models; + + public ModelTester(String modelConfigDirectory) { + models = createModels(modelConfigDirectory); + } + + public Map<String, Model> models() { return models; } + + private static Map<String, Model> createModels(String path) { + Path configDir = Path.fromString(path); + RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), + RankProfilesConfig.class).getConfig(""); + RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), + RankingConstantsConfig.class).getConfig(""); + return new RankProfilesConfigImporterWithMockedConstants(Path.fromString(path).append("constants"), MockFileAcquirer.returnFile(null)) + .importFrom(config, constantsConfig); + } + + public void assertFunction(String name, String expression, Model model) { + assertNotNull("Model is present in config", model); + ExpressionFunction function = model.function(name); + assertNotNull("Function '" + name + "' is in " + model, function); + assertEquals(name, function.getName()); + assertEquals(expression, function.getBody().getRoot().toString()); + } + + public void assertBoundFunction(String name, String expression, Model model) { + ExpressionFunction function = model.referencedFunctions().get(FunctionReference.fromSerial(name).get()); + assertNotNull("Function '" + name + "' is present", function); + assertEquals(name, function.getName()); + assertEquals(expression, function.getBody().getRoot().toString()); + } + + /** Allows us to provide canned tensor constants during import since file distribution does not work in tests */ + private static class RankProfilesConfigImporterWithMockedConstants extends RankProfilesConfigImporter { + + private static final Logger log = Logger.getLogger(RankProfilesConfigImporterWithMockedConstants.class.getName()); + + private final Path constantsPath; + + public RankProfilesConfigImporterWithMockedConstants(Path constantsPath, FileAcquirer fileAcquirer) { + super(fileAcquirer); + this.constantsPath = constantsPath; + } + + @Override + protected Tensor readTensorFromFile(String name, TensorType type, FileReference fileReference) { + try { + return TypedBinaryFormat.decode(Optional.of(type), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantsPath.append(name).toFile()))); + } + catch (IOException e) { + log.warning("Missing a mocked tensor constant for '" + name + "': " + e.getMessage() + + ". Returning an empty tensor"); + return Tensor.from(type, "{}"); + } + } + + } + +} diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index 60cf0d25ded..bd1ff6b8ed7 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -3,12 +3,13 @@ package ai.vespa.models.evaluation; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.config.subscription.FileSource; +import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; +import com.yahoo.path.Path; import com.yahoo.tensor.Tensor; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import org.junit.Test; -import java.io.File; - import static org.junit.Assert.assertEquals; /** @@ -18,15 +19,9 @@ public class ModelsEvaluatorTest { private static final double delta = 0.00000000001; - private ModelsEvaluator createModels() { - String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg"; - RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); - return new ModelsEvaluator(config); - } - @Test public void testTensorEvaluation() { - ModelsEvaluator models = createModels(); + ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); FunctionEvaluator function = models.evaluatorOf("macros", "fourtimessum"); function.bind("var1", Tensor.from("{{x:0}:3,{x:1}:5}")); function.bind("var2", Tensor.from("{{x:0}:7,{x:1}:11}")); @@ -35,7 +30,7 @@ public class ModelsEvaluatorTest { @Test public void testEvaluationDependingOnMacroTakingArguments() { - ModelsEvaluator models = createModels(); + ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); FunctionEvaluator function = models.evaluatorOf("macros", "secondphase"); function.bind("match", 3); function.bind("rankBoost", 5); @@ -46,6 +41,14 @@ public class ModelsEvaluatorTest { // TODO: Test that binding nonexisting variable doesn't work // TODO: Test that rebinding doesn't work // TODO: Test with nested macros - // TODO: Test TF/ONNX model + + private ModelsEvaluator createModels(String path) { + Path configDir = Path.fromString(path); + RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), + RankProfilesConfig.class).getConfig(""); + RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), + RankingConstantsConfig.class).getConfig(""); + return new ModelsEvaluator(config, constantsConfig, MockFileAcquirer.returnFile(null)); + } } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java new file mode 100644 index 00000000000..20abd9c0fb0 --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java @@ -0,0 +1,34 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class RankProfileImportingTest { + + @Test + public void testImportingRankExpressions() { + ModelTester tester = new ModelTester("src/test/resources/config/rankexpression/"); + + assertEquals(18, tester.models().size()); + + Model macros = tester.models().get("macros"); + assertEquals("macros", macros.name()); + assertEquals(4, macros.functions().size()); + tester.assertFunction("fourtimessum", "4 * (var1 + var2)", macros); + tester.assertFunction("firstphase", "match + fieldMatch(title) + rankingExpression(myfeature)", macros); + tester.assertFunction("secondphase", "rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", macros); + tester.assertFunction("myfeature", + "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + " + + "30 * pow(0 - fieldMatch(description).earliness,2)", + macros); + assertEquals(4, macros.referencedFunctions().size()); + tester.assertBoundFunction("rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", + "4 * (match + rankBoost)", macros); + } + +} diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java deleted file mode 100644 index d45372fc7da..00000000000 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.models.evaluation; - -import com.yahoo.config.subscription.ConfigGetter; -import com.yahoo.config.subscription.FileSource; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.vespa.config.search.RankProfilesConfig; -import org.junit.Test; - -import java.io.File; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - -/** - * Tests instantiating models from rank-profiles configs. - * - * @author bratseth - */ -public class RankProfilesImporterTest { - - @Test - public void testImporting() { - String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg"; - RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); - Map<String, Model> models = new RankProfilesConfigImporter().importFrom(config); - assertEquals(18, models.size()); - - Model macros = models.get("macros"); - assertNotNull(macros); - assertEquals("macros", macros.name()); - assertEquals(4, macros.functions().size()); - assertFunction("fourtimessum", "4 * (var1 + var2)", macros); - assertFunction("firstphase", "match + fieldMatch(title) + rankingExpression(myfeature)", macros); - assertFunction("secondphase", "rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", macros); - assertFunction("myfeature", - "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + " + - "30 * pow(0 - fieldMatch(description).earliness,2)", - macros); - assertEquals(4, macros.referencedFunctions().size()); - assertBoundFunction("rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", - "4 * (match + rankBoost)", macros); - } - - private void assertFunction(String name, String expression, Model model) { - ExpressionFunction function = model.function(name); - assertNotNull(function); - assertEquals(name, function.getName()); - assertEquals(expression, function.getBody().getRoot().toString()); - } - - private void assertBoundFunction(String name, String expression, Model model) { - ExpressionFunction function = model.referencedFunctions().get(FunctionReference.fromSerial(name).get()); - assertNotNull("Function '" + name + "' is present", function); - assertEquals(name, function.getName()); - assertEquals(expression, function.getBody().getRoot().toString()); - } - -} diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/SmallConstantImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/SmallConstantImportingTest.java new file mode 100644 index 00000000000..5fd2e1de4ed --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/SmallConstantImportingTest.java @@ -0,0 +1,25 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class SmallConstantImportingTest { + + @Test + public void testImportingSmallConstant() { + ModelTester tester = new ModelTester("src/test/resources/config/smallconstant/"); + + assertEquals(1, tester.models().size()); + + Model model = tester.models().get("my_profile"); + tester.assertFunction("firstphase", "reduce(constant(my_tensor), sum)", model); + assertEquals(3.0, model.evaluatorOf().evaluate().asDouble(), 0.00000000001); + + } + +} diff --git a/model-evaluation/src/test/resources/config/models/constants/README b/model-evaluation/src/test/resources/config/models/constants/README new file mode 100644 index 00000000000..4a274aa95c8 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/README @@ -0,0 +1 @@ +These constants was created by writing TypedBinaryFormat.encode(tensor) on each large constant produced by these models. diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden1_bias_read b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden1_bias_read Binary files differnew file mode 100644 index 00000000000..bac75f7b1e7 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden1_bias_read diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden1_weights_read b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden1_weights_read Binary files differnew file mode 100644 index 00000000000..bd3f05be826 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden1_weights_read diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden2_bias_read b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden2_bias_read Binary files differnew file mode 100644 index 00000000000..fca7c76df3f --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden2_bias_read diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden2_weights_read b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden2_weights_read Binary files differnew file mode 100644 index 00000000000..396dea8f4bc --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden2_weights_read diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_outputs_bias_read b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_outputs_bias_read Binary files differnew file mode 100644 index 00000000000..42f85478c10 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_outputs_bias_read diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_outputs_weights_read b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_outputs_weights_read Binary files differnew file mode 100644 index 00000000000..a3cc7d765f6 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_outputs_weights_read diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_Variable b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_Variable Binary files differnew file mode 100644 index 00000000000..e768328bff5 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_Variable diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_Variable_1 b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_Variable_1 Binary files differnew file mode 100644 index 00000000000..4fa0eadb0d3 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_Variable_1 diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read Binary files differnew file mode 100644 index 00000000000..4fa0eadb0d3 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read Binary files differnew file mode 100644 index 00000000000..e768328bff5 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg new file mode 100644 index 00000000000..1cc36f75158 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg @@ -0,0 +1,14 @@ +rankprofile[0].name "mnist_saved" +rankprofile[0].fef.property[0].name "rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add).rankingScript" +rankprofile[0].fef.property[0].value "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))" +rankprofile[0].fef.property[1].name "rankingExpression(serving_default.y).rankingScript" +rankprofile[0].fef.property[1].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))" +rankprofile[1].name "xgboost_2_2" +rankprofile[1].fef.property[0].name "rankingExpression(xgboost_2_2).rankingScript" +rankprofile[1].fef.property[0].value "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)" +rankprofile[2].name "mnist_softmax_saved" +rankprofile[2].fef.property[0].name "rankingExpression(serving_default.y).rankingScript" +rankprofile[2].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))" +rankprofile[3].name "mnist_softmax" +rankprofile[3].fef.property[0].name "rankingExpression(default.add).rankingScript" +rankprofile[3].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))" diff --git a/model-evaluation/src/test/resources/config/models/ranking-constants.cfg b/model-evaluation/src/test/resources/config/models/ranking-constants.cfg new file mode 100644 index 00000000000..2b7495ace5e --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/ranking-constants.cfg @@ -0,0 +1,30 @@ +constant[0].name "mnist_saved_dnn_hidden1_weights_read" +constant[0].fileref "" +constant[0].type "tensor(d3[300],d4[784])" +constant[1].name "mnist_saved_dnn_hidden2_weights_read" +constant[1].fileref "" +constant[1].type "tensor(d2[100],d3[300])" +constant[2].name "mnist_softmax_saved_layer_Variable_1_read" +constant[2].fileref "" +constant[2].type "tensor(d1[10])" +constant[3].name "mnist_saved_dnn_hidden1_bias_read" +constant[3].fileref "" +constant[3].type "tensor(d3[300])" +constant[4].name "mnist_saved_dnn_hidden2_bias_read" +constant[4].fileref "" +constant[4].type "tensor(d2[100])" +constant[5].name "mnist_softmax_Variable" +constant[5].fileref "" +constant[5].type "tensor(d1[10],d2[784])" +constant[6].name "mnist_saved_dnn_outputs_weights_read" +constant[6].fileref "" +constant[6].type "tensor(d1[10],d2[100])" +constant[7].name "mnist_softmax_saved_layer_Variable_read" +constant[7].fileref "" +constant[7].type "tensor(d1[10],d2[784])" +constant[8].name "mnist_softmax_Variable_1" +constant[8].fileref "" +constant[8].type "tensor(d1[10])" +constant[9].name "mnist_saved_dnn_outputs_bias_read" +constant[9].fileref "" +constant[9].type "tensor(d1[10])"
\ No newline at end of file diff --git a/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg b/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg diff --git a/model-evaluation/src/test/resources/config/smallconstant/rank-profiles.cfg b/model-evaluation/src/test/resources/config/smallconstant/rank-profiles.cfg new file mode 100644 index 00000000000..840f1a06296 --- /dev/null +++ b/model-evaluation/src/test/resources/config/smallconstant/rank-profiles.cfg @@ -0,0 +1,9 @@ +rankprofile[0].name "my_profile" +rankprofile[0].fef.property[0].name "constant(my_tensor).type" +rankprofile[0].fef.property[0].value "tensor(x{},y{})" +rankprofile[0].fef.property[1].name "constant(my_tensor).value" +rankprofile[0].fef.property[1].value "{{x:1,y:2}:1.0,{x:2,y:1}:2.0}" +rankprofile[0].fef.property[2].name "vespa.rank.firstphase" +rankprofile[0].fef.property[2].value "rankingExpression(firstphase)" +rankprofile[0].fef.property[3].name "rankingExpression(firstphase).rankingScript" +rankprofile[0].fef.property[3].value "reduce(constant(my_tensor), sum)" diff --git a/model-evaluation/src/test/resources/config/smallconstant/ranking-constants.cfg b/model-evaluation/src/test/resources/config/smallconstant/ranking-constants.cfg new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/model-evaluation/src/test/resources/config/smallconstant/ranking-constants.cfg diff --git a/model-evaluation/src/test/resources/config/smallconstant/smallconstant.sd b/model-evaluation/src/test/resources/config/smallconstant/smallconstant.sd new file mode 100644 index 00000000000..47d2d95b968 --- /dev/null +++ b/model-evaluation/src/test/resources/config/smallconstant/smallconstant.sd @@ -0,0 +1,18 @@ +search smallconstant { + + document smallconstant { + } + + rank-profile my_profile { + first-phase { + expression: sum(my_tensor) + } + constants { + my_tensor + value: { {x:1,y:2}:1, {x:2,y:1}:2 } + type: tensor(x{},y{}) + } + } + } + +}
\ No newline at end of file diff --git a/node-admin/pom.xml b/node-admin/pom.xml index 7daeacec463..64958554f53 100644 --- a/node-admin/pom.xml +++ b/node-admin/pom.xml @@ -18,6 +18,7 @@ <name>${project.artifactId}</name> <dependencies> + <!-- Provided --> <dependency> <groupId>com.yahoo.vespa</groupId> <artifactId>docker-api</artifactId> @@ -32,49 +33,55 @@ </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>node-repository</artifactId> + <artifactId>defaults</artifactId> <version>${project.version}</version> + <scope>provided</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>defaults</artifactId> + <artifactId>container-dev</artifactId> <version>${project.version}</version> <scope>provided</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>container-dev</artifactId> + <artifactId>vespa-athenz</artifactId> <version>${project.version}</version> <scope>provided</scope> </dependency> + + <!-- Compile --> <dependency> - <groupId>net.jpountz.lz4</groupId> - <artifactId>lz4</artifactId> + <groupId>com.yahoo.vespa</groupId> + <artifactId>orchestrator-restapi</artifactId> + <version>${project.version}</version> + <scope>compile</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>node-repository</artifactId> + <version>${project.version}</version> <scope>compile</scope> </dependency> <dependency> <groupId>org.apache.httpcomponents</groupId> <artifactId>httpcore</artifactId> <version>4.4.1</version> + <scope>compile</scope> </dependency> <dependency> <groupId>org.apache.httpcomponents</groupId> <artifactId>httpclient</artifactId> <version>4.5</version> - </dependency> - <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>orchestrator-restapi</artifactId> - <version>${project.version}</version> <scope>compile</scope> </dependency> <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>vespa-athenz</artifactId> - <version>${project.version}</version> - <scope>provided</scope> + <groupId>org.apache.velocity</groupId> + <artifactId>velocity</artifactId> + <scope>compile</scope> </dependency> + <!-- Test --> <dependency> <groupId>org.hamcrest</groupId> <artifactId>hamcrest-junit</artifactId> @@ -82,6 +89,11 @@ <scope>test</scope> </dependency> <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-core</artifactId> + <scope>test</scope> + </dependency> + <dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> <scope>test</scope> @@ -89,24 +101,24 @@ <dependency> <groupId>com.yahoo.vespa</groupId> <artifactId>application</artifactId> - <scope>test</scope> <version>${project.version}</version> + <scope>test</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>orchestrator</artifactId> + <artifactId>application-model</artifactId> <version>${project.version}</version> <scope>test</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>service-monitor</artifactId> + <artifactId>orchestrator</artifactId> <version>${project.version}</version> <scope>test</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>application-model</artifactId> + <artifactId>service-monitor</artifactId> <version>${project.version}</version> <scope>test</scope> </dependency> @@ -116,16 +128,6 @@ <version>${project.version}</version> <scope>test</scope> </dependency> - <dependency> - <groupId>org.mockito</groupId> - <artifactId>mockito-core</artifactId> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.apache.velocity</groupId> - <artifactId>velocity</artifactId> - <scope>compile</scope> - </dependency> </dependencies> <build> <plugins> diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/component/Environment.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/component/Environment.java index c9f17b7cbf6..dc0ac0df05d 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/component/Environment.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/component/Environment.java @@ -8,22 +8,21 @@ import com.yahoo.vespa.athenz.utils.AthenzIdentities; import com.yahoo.vespa.defaults.Defaults; import com.yahoo.vespa.hosted.dockerapi.ContainerName; import com.yahoo.vespa.hosted.node.admin.config.ConfigServerConfig; +import com.yahoo.vespa.hosted.node.admin.docker.DockerNetworking; import com.yahoo.vespa.hosted.node.admin.task.util.network.IPAddresses; import com.yahoo.vespa.hosted.node.admin.task.util.network.IPAddressesImpl; import java.net.URI; import java.nio.file.Path; import java.nio.file.Paths; -import java.text.DateFormat; -import java.text.SimpleDateFormat; import java.time.Instant; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; import java.util.Arrays; import java.util.Collections; -import java.util.Date; import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.TimeZone; /** * Various utilities for getting values from node-admin's environment. Immutable. @@ -32,7 +31,8 @@ import java.util.TimeZone; * @author hmusum */ public class Environment { - private static final DateFormat filenameFormatter = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS"); + private static final DateTimeFormatter filenameFormatter = DateTimeFormatter + .ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS").withZone(ZoneOffset.UTC); public static final String APPLICATION_STORAGE_CLEANUP_PATH_PREFIX = "cleanup_"; private static final String ENVIRONMENT = "ENVIRONMENT"; @@ -51,13 +51,13 @@ public class Environment { private final String environment; private final String region; private final String system; + private final String cloud; private final String parentHostHostname; private final IPAddresses ipAddresses; private final PathResolver pathResolver; private final List<String> logstashNodes; private final Optional<String> coredumpFeedEndpoint; private final NodeType nodeType; - private final String cloud; private final ContainerEnvironmentResolver containerEnvironmentResolver; private final String certificateDnsSuffix; private final URI ztsUri; @@ -65,10 +65,7 @@ public class Environment { private final boolean nodeAgentCertEnabled; private final boolean isRunningOnHost; private final Path trustStorePath; - - static { - filenameFormatter.setTimeZone(TimeZone.getTimeZone("UTC")); - } + private final DockerNetworking dockerNetworking; public Environment(ConfigServerConfig configServerConfig) { this(configServerConfig, @@ -76,19 +73,20 @@ public class Environment { getEnvironmentVariable(ENVIRONMENT), getEnvironmentVariable(REGION), getEnvironmentVariable(SYSTEM), + getEnvironmentVariable(CLOUD), Defaults.getDefaults().vespaHostname(), new IPAddressesImpl(), new PathResolver(), getLogstashNodesFromEnvironment(), Optional.of(getEnvironmentVariable(COREDUMP_FEED_ENDPOINT)), NodeType.host, - getEnvironmentVariable(CLOUD), new DefaultContainerEnvironmentResolver(), getEnvironmentVariable(CERTIFICATE_DNS_SUFFIX), URI.create(getEnvironmentVariable(ZTS_URI)), (AthenzService)AthenzIdentities.from(getEnvironmentVariable(NODE_ATHENZ_IDENTITY)), Boolean.valueOf(getEnvironmentVariable(ENABLE_NODE_AGENT_CERT)), - false); + false, + DockerNetworking.MACVLAN); } private Environment(ConfigServerConfig configServerConfig, @@ -96,36 +94,33 @@ public class Environment { String environment, String region, String system, + String cloud, String parentHostHostname, IPAddresses ipAddresses, PathResolver pathResolver, List<String> logstashNodes, Optional<String> coreDumpFeedEndpoint, NodeType nodeType, - String cloud, ContainerEnvironmentResolver containerEnvironmentResolver, String certificateDnsSuffix, URI ztsUri, AthenzService nodeAthenzIdentity, boolean nodeAgentCertEnabled, - boolean isRunningOnHost) { + boolean isRunningOnHost, + DockerNetworking dockerNetworking) { Objects.requireNonNull(configServerConfig, "configServerConfig cannot be null"); - Objects.requireNonNull(environment, "environment cannot be null"); - Objects.requireNonNull(region, "region cannot be null"); - Objects.requireNonNull(system, "system cannot be null"); - Objects.requireNonNull(cloud, "cloud cannot be null"); this.configServerInfo = new ConfigServerInfo(configServerConfig); - this.environment = environment; - this.region = region; - this.system = system; + this.environment = Objects.requireNonNull(environment, "environment cannot be null");; + this.region = Objects.requireNonNull(region, "region cannot be null");; + this.system = Objects.requireNonNull(system, "system cannot be null");; + this.cloud = Objects.requireNonNull(cloud, "cloud cannot be null"); this.parentHostHostname = parentHostHostname; this.ipAddresses = ipAddresses; this.pathResolver = pathResolver; this.logstashNodes = logstashNodes; this.coredumpFeedEndpoint = coreDumpFeedEndpoint; this.nodeType = nodeType; - this.cloud = cloud; this.containerEnvironmentResolver = containerEnvironmentResolver; this.certificateDnsSuffix = certificateDnsSuffix; this.ztsUri = ztsUri; @@ -133,6 +128,7 @@ public class Environment { this.nodeAgentCertEnabled = nodeAgentCertEnabled; this.isRunningOnHost = isRunningOnHost; this.trustStorePath = trustStorePath; + this.dockerNetworking = Objects.requireNonNull(dockerNetworking, "dockerNetworking cannot be null"); } public List<String> getConfigServerHostNames() { return configServerInfo.getConfigServerHostNames(); } @@ -147,6 +143,8 @@ public class Environment { return system; } + public String getCloud() { return cloud; } + public String getParentHostHostname() { return parentHostHostname; } @@ -196,7 +194,7 @@ public class Environment { public Path pathInNodeAdminToNodeCleanup(ContainerName containerName) { return pathResolver.getApplicationStoragePathForNodeAdmin() .resolve(APPLICATION_STORAGE_CLEANUP_PATH_PREFIX + containerName.asString() + - "_" + filenameFormatter.format(Date.from(Instant.now()))); + "_" + filenameFormatter.format(Instant.now())); } /** @@ -242,8 +240,6 @@ public class Environment { public NodeType getNodeType() { return nodeType; } - public String getCloud() { return cloud; } - public ContainerEnvironmentResolver getContainerEnvironmentResolver() { return containerEnvironmentResolver; } @@ -280,18 +276,22 @@ public class Environment { return isRunningOnHost; } + public DockerNetworking getDockerNetworking() { + return dockerNetworking; + } + public static class Builder { private ConfigServerConfig configServerConfig; private String environment; private String region; private String system; + private String cloud; private String parentHostHostname; private IPAddresses ipAddresses; private PathResolver pathResolver; private List<String> logstashNodes = Collections.emptyList(); private Optional<String> coredumpFeedEndpoint = Optional.empty(); private NodeType nodeType = NodeType.tenant; - private String cloud; private ContainerEnvironmentResolver containerEnvironmentResolver; private String certificateDnsSuffix; private URI ztsUri; @@ -299,6 +299,7 @@ public class Environment { private boolean nodeAgentCertEnabled; private boolean isRunningOnHost; private Path trustStorePath; + private DockerNetworking dockerNetworking; public Builder configServerConfig(ConfigServerConfig configServerConfig) { this.configServerConfig = configServerConfig; @@ -320,6 +321,11 @@ public class Environment { return this; } + public Builder cloud(String cloud) { + this.cloud = cloud; + return this; + } + public Builder parentHostHostname(String parentHostHostname) { this.parentHostHostname = parentHostHostname; return this; @@ -355,11 +361,6 @@ public class Environment { return this; } - public Builder cloud(String cloud) { - this.cloud = cloud; - return this; - } - public Builder certificateDnsSuffix(String certificateDnsSuffix) { this.certificateDnsSuffix = certificateDnsSuffix; return this; @@ -390,25 +391,31 @@ public class Environment { return this; } + public Builder dockerNetworking(DockerNetworking dockerNetworking) { + this.dockerNetworking = dockerNetworking; + return this; + } + public Environment build() { return new Environment(configServerConfig, trustStorePath, environment, region, system, + cloud, parentHostHostname, Optional.ofNullable(ipAddresses).orElseGet(IPAddressesImpl::new), Optional.ofNullable(pathResolver).orElseGet(PathResolver::new), logstashNodes, coredumpFeedEndpoint, nodeType, - cloud, Optional.ofNullable(containerEnvironmentResolver).orElseGet(DefaultContainerEnvironmentResolver::new), certificateDnsSuffix, ztsUri, nodeAthenzIdentity, nodeAgentCertEnabled, - isRunningOnHost); + isRunningOnHost, + Optional.ofNullable(dockerNetworking).orElseGet(() -> DockerNetworking.from(cloud, nodeType, isRunningOnHost))); } } } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeRepository.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeRepository.java index cbb714c3779..5fc82a70e80 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeRepository.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeRepository.java @@ -27,4 +27,6 @@ public interface NodeRepository { void updateNodeAttributes(String hostName, NodeAttributes nodeAttributes); void setNodeState(String hostName, Node.State nodeState); + + void scheduleReboot(String hostname); } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java index f5f0fa5a3f1..7036f6852fe 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java @@ -641,6 +641,7 @@ public class NodeSpec { public Builder updateFromNodeAttributes(NodeAttributes attributes) { attributes.getDockerImage().ifPresent(this::currentDockerImage); + attributes.getCurrentOsVersion().ifPresent(this::currentOsVersion); attributes.getHardwareDivergence().ifPresent(this::hardwareDivergence); attributes.getRebootGeneration().ifPresent(this::currentRebootGeneration); attributes.getRestartGeneration().ifPresent(this::currentRestartGeneration); diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java index 3b86869e72f..7cee1730804 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java @@ -41,9 +41,8 @@ public class RealNodeRepository implements NodeRepository { .collect(Collectors.toList()); NodeMessageResponse response = configServerApi.post("/nodes/v2/node", nodesToPost, NodeMessageResponse.class); - if (!Strings.isNullOrEmpty(response.errorCode)) { - throw new NodeRepositoryException("Failed to add nodes to node-repo: " + response.message + " " + response.errorCode); - } + if (Strings.isNullOrEmpty(response.errorCode)) return; + throw new NodeRepositoryException("Failed to add nodes to node-repo: " + response.message + " " + response.errorCode); } @Override @@ -61,10 +60,8 @@ public class RealNodeRepository implements NodeRepository { try { NodeRepositoryNode nodeResponse = configServerApi.get("/nodes/v2/node/" + hostName, NodeRepositoryNode.class); - if (nodeResponse == null) { - return Optional.empty(); - } - return Optional.of(createNodeSpec(nodeResponse)); + + return Optional.ofNullable(nodeResponse).map(RealNodeRepository::createNodeSpec); } catch (HttpException.NotFoundException | HttpException.ForbiddenException e) { // Return empty on 403 in addition to 404 as it likely means we're trying to access a node that // has been deleted. When a node is deleted, the parent-child relationship no longer exists and @@ -117,9 +114,8 @@ public class RealNodeRepository implements NodeRepository { nodeRepositoryNodeFromNodeAttributes(nodeAttributes), NodeMessageResponse.class); - if (!Strings.isNullOrEmpty(response.errorCode)) { - throw new NodeRepositoryException("Unexpected message " + response.message + " " + response.errorCode); - } + if (Strings.isNullOrEmpty(response.errorCode)) return; + throw new NodeRepositoryException("Unexpected message " + response.message + " " + response.errorCode); } @Override @@ -131,9 +127,19 @@ public class RealNodeRepository implements NodeRepository { NodeMessageResponse.class); NODE_ADMIN_LOGGER.info(response.message); - if (response.errorCode == null || response.errorCode.isEmpty()) { - return; - } + if (Strings.isNullOrEmpty(response.errorCode)) return; + throw new NodeRepositoryException("Unexpected message " + response.message + " " + response.errorCode); + } + + @Override + public void scheduleReboot(String hostName) { + NodeMessageResponse response = configServerApi.post( + "/nodes/v2/command/reboot?hostname=" + hostName, + Optional.empty(), /* body */ + NodeMessageResponse.class); + NODE_ADMIN_LOGGER.info(response.message); + + if (Strings.isNullOrEmpty(response.errorCode)) return; throw new NodeRepositoryException("Unexpected message " + response.message + " " + response.errorCode); } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerNetworking.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerNetworking.java new file mode 100644 index 00000000000..7678ad8169a --- /dev/null +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerNetworking.java @@ -0,0 +1,41 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.node.admin.docker;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import com.yahoo.config.provision.NodeType; + +/** + * The types of network setup for the Docker containers. + * + * @author hakon + */ +public enum DockerNetworking { + /** Each container has an associated macvlan bridge. */ + MACVLAN("vespa-macvlan"), + + /** Network Prefix-Translated networking. */ + NPT("vespa-bridge"), + + /** A host running a single container in the host network namespace. */ + HOST_NETWORK("host"); + + private final String dockerNetworkMode; + DockerNetworking(String dockerNetworkMode) { + this.dockerNetworkMode = dockerNetworkMode; + } + + public String getDockerNetworkMode() { + return dockerNetworkMode; + } + + public static DockerNetworking from(String cloud, NodeType nodeType, boolean hostAdmin) { + if (cloud.equals("AWS")) { + return DockerNetworking.NPT; + } else if (nodeType == NodeType.confighost || nodeType == NodeType.proxyhost) { + return DockerNetworking.HOST_NETWORK; + } else if (hostAdmin) { + return DockerNetworking.NPT; + } else { + return DockerNetworking.MACVLAN; + } + } +} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java index f3b5dc9342a..a197eafe923 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java @@ -10,7 +10,6 @@ import com.yahoo.vespa.hosted.dockerapi.ContainerName; import com.yahoo.vespa.hosted.dockerapi.ContainerResources; import com.yahoo.vespa.hosted.dockerapi.Docker; import com.yahoo.vespa.hosted.dockerapi.DockerImage; -import com.yahoo.vespa.hosted.dockerapi.DockerImpl; import com.yahoo.vespa.hosted.dockerapi.DockerNetworkCreator; import com.yahoo.vespa.hosted.dockerapi.ProcessResult; import com.yahoo.vespa.hosted.node.admin.component.Environment; @@ -45,8 +44,7 @@ public class DockerOperationsImpl implements DockerOperations { private static final String IPV6_NPT_PREFIX = "fd00::"; private static final String IPV4_NPT_PREFIX = "172.17.0.0"; - private static final String DOCKER_CUSTOM_BRIDGE_NETWORK_NAME = "vespa-bridge"; - + private final Docker docker; private final Environment environment; private final ProcessExecuter processExecuter; @@ -87,34 +85,33 @@ public class DockerOperationsImpl implements DockerOperations { .withUlimit("nproc", 32_768, 409_600) .withUlimit("core", -1, -1) .withAddCapability("SYS_PTRACE") // Needed for gcore, pstack etc. - .withAddCapability("SYS_ADMIN") // Needed for perf - - // TODO: Fix. Run containers as privileged in AWS because mapped directories are on another device - .withPrivileged(environment.getCloud().equalsIgnoreCase("aws")); + .withAddCapability("SYS_ADMIN"); // Needed for perf if (environment.getNodeType() == NodeType.confighost || environment.getNodeType() == NodeType.proxyhost) { command.withVolume("/var/lib/sia", "/var/lib/sia"); } + if (environment.getNodeType() == NodeType.proxyhost) { + command.withVolume("/opt/yahoo/share/ssl/certs/", "/opt/yahoo/share/ssl/certs/"); + } + if (environment.getNodeType() == NodeType.host) { Path zpePathInNode = environment.pathInNodeUnderVespaHome("var/zpe"); if (environment.isRunningOnHost()) { - command.withVolume("/var/zpe", zpePathInNode.toString()); + command.withSharedVolume("/var/zpe", zpePathInNode.toString()); } else { command.withVolume(environment.pathInHostFromPathInNode(containerName, zpePathInNode).toString(), zpePathInNode.toString()); } } - if (environment.getNodeType() == NodeType.proxyhost) { - command.withVolume("/opt/yahoo/share/ssl/certs/", "/opt/yahoo/share/ssl/certs/"); - } + DockerNetworking networking = environment.getDockerNetworking(); + command.withNetworkMode(networking.getDockerNetworkMode()); - if (!docker.networkNATed()) { + if (networking == DockerNetworking.MACVLAN) { // TODO: Remove this if when migration to host-admin is complete command.withIpAddress(ipV6Address); - command.withNetworkMode(DockerImpl.DOCKER_CUSTOM_MACVLAN_NETWORK_NAME); - command.withVolume("/etc/hosts", "/etc/hosts"); - } else { + command.withSharedVolume("/etc/hosts", "/etc/hosts"); + } else if (networking == DockerNetworking.NPT) { InetAddress ipV6Prefix = InetAddresses.forString(IPV6_NPT_PREFIX); InetAddress ipV6Local = IPAddresses.prefixTranslate(ipV6Address, ipV6Prefix, 8); command.withIpAddress(ipV6Local); @@ -128,8 +125,6 @@ public class DockerOperationsImpl implements DockerOperations { ipV4Local.ifPresent(command::withIpAddress); addEtcHosts(containerData, node.getHostname(), ipV4Local, ipV6Local); - - command.withNetworkMode(DOCKER_CUSTOM_BRIDGE_NETWORK_NAME); } for (Path pathInNode : directoriesToMount.keySet()) { @@ -186,13 +181,13 @@ public class DockerOperationsImpl implements DockerOperations { PrefixLogger logger = PrefixLogger.getNodeAgentLogger(DockerOperationsImpl.class, containerName); logger.info("Starting container " + containerName); - if (!docker.networkNATed()) { + if (environment.getDockerNetworking() == DockerNetworking.MACVLAN) { docker.connectContainerToNetwork(containerName, "bridge"); } docker.startContainer(containerName); - if (!docker.networkNATed()) { + if (environment.getDockerNetworking() == DockerNetworking.MACVLAN) { setupContainerNetworkConnectivity(containerName); } @@ -368,9 +363,6 @@ public class DockerOperationsImpl implements DockerOperations { directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/db/vespa"), false); directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/jdisc_container"), false); directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/jdisc_core"), false); - if (environment.getNodeType() == NodeType.host) { - directoriesToMount.put(Paths.get("/var/lib/sia"), true); - } directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/maven"), false); directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/run"), false); directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/scoreboards"), true); @@ -385,6 +377,8 @@ public class DockerOperationsImpl implements DockerOperations { directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/container-data"), false); if (environment.getNodeType() == NodeType.proxyhost) directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/vespa-hosted/routing"), true); + if (environment.getNodeType() == NodeType.host) + directoriesToMount.put(Paths.get("/var/lib/sia"), true); return Collections.unmodifiableMap(directoriesToMount); } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java index 1fe2719d2a0..cdfc8eef798 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java @@ -76,9 +76,14 @@ public class StorageMaintainer { this.coredumpHandler = Optional.ofNullable(coredumpHandler); this.clock = clock; - Dimensions dimensions = new Dimensions.Builder().add("role", "docker").build(); + Dimensions dimensions = new Dimensions.Builder() + .add("role", SecretAgentCheckConfig.nodeTypeToRole(environment.getNodeType())) + .build(); numberOfNodeAdminMaintenanceFails = metricReceiver.declareCounter(MetricReceiverWrapper.APPLICATION_DOCKER, dimensions, "nodes.maintenance.fails"); numberOfCoredumpsOnHost = metricReceiver.declareGauge(MetricReceiverWrapper.APPLICATION_DOCKER, dimensions, "nodes.coredumps"); + + metricReceiver.declareCounter(MetricReceiverWrapper.APPLICATION_DOCKER, dimensions, "nodes.running_on_host") + .add(environment.isRunningOnHost() ? 1 : 0); } public void writeMetricsConfig(ContainerName containerName, NodeSpec node) { diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/acl/AclMaintainer.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/acl/AclMaintainer.java index 80a702ead1e..9259b522d17 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/acl/AclMaintainer.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/acl/AclMaintainer.java @@ -6,6 +6,7 @@ import com.yahoo.vespa.hosted.dockerapi.Container; import com.yahoo.vespa.hosted.node.admin.component.Environment; import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.Acl; import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeRepository; +import com.yahoo.vespa.hosted.node.admin.docker.DockerNetworking; import com.yahoo.vespa.hosted.node.admin.docker.DockerOperations; import com.yahoo.vespa.hosted.node.admin.task.util.network.IPAddresses; import com.yahoo.vespa.hosted.node.admin.task.util.network.IPVersion; @@ -51,6 +52,8 @@ public class AclMaintainer implements Runnable { private void applyRedirect(Container container, InetAddress address) { IPVersion ipVersion = IPVersion.get(address); + // Necessary to avoid the routing packets destined for the node's own public IP address + // via the bridge, which is illegal. String redirectRule = "-A OUTPUT -d " + InetAddresses.toAddrString(address) + ipVersion.singleHostCidr() + " -j REDIRECT"; IPTablesEditor.editLogOnError(dockerOperations, container.name, ipVersion, "nat", NatTableLineEditor.from(redirectRule)); } @@ -61,7 +64,7 @@ public class AclMaintainer implements Runnable { IPTablesEditor.editFlushOnError(dockerOperations, container.name, IPVersion.IPv4, "filter", FilterTableLineEditor.from(acl, IPVersion.IPv4)); // Apply redirect to the nat table - if (this.environment.getCloud().equals("AWS")) { + if (environment.getDockerNetworking() == DockerNetworking.NPT) { ipAddresses.getAddress(container.hostname, IPVersion.IPv4).ifPresent(addr -> applyRedirect(container, addr)); ipAddresses.getAddress(container.hostname, IPVersion.IPv6).ifPresent(addr -> applyRedirect(container, addr)); } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java index f82047d885c..3871bb82313 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java @@ -1,6 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.node.admin.maintenance.identity; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.SslContextBuilder; +import com.yahoo.security.X509CertificateUtils; import com.yahoo.vespa.athenz.api.AthenzService; import com.yahoo.vespa.athenz.client.zts.DefaultZtsClient; import com.yahoo.vespa.athenz.client.zts.InstanceIdentity; @@ -13,12 +18,6 @@ import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; import com.yahoo.vespa.athenz.identityprovider.client.DefaultIdentityDocumentClient; import com.yahoo.vespa.athenz.identityprovider.client.InstanceCsrGenerator; import com.yahoo.vespa.athenz.tls.AthenzIdentityVerifier; -import com.yahoo.vespa.athenz.tls.KeyAlgorithm; -import com.yahoo.vespa.athenz.tls.KeyStoreType; -import com.yahoo.vespa.athenz.tls.KeyUtils; -import com.yahoo.vespa.athenz.tls.Pkcs10Csr; -import com.yahoo.vespa.athenz.tls.SslContextBuilder; -import com.yahoo.vespa.athenz.tls.X509CertificateUtils; import com.yahoo.vespa.athenz.utils.SiaUtils; import com.yahoo.vespa.hosted.dockerapi.ContainerName; import com.yahoo.vespa.hosted.node.admin.component.Environment; @@ -169,10 +168,11 @@ public class AthenzCredentialsMaintainer { return now.isAfter(expiry.minus(EXPIRY_MARGIN)); } + @SuppressWarnings("deprecation") private void registerIdentity() { KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); SignedIdentityDocument signedIdentityDocument = identityDocumentClient.getNodeIdentityDocument(hostname); - Pkcs10Csr csr = csrGenerator.generateCsr( + com.yahoo.vespa.athenz.tls.Pkcs10Csr csr = csrGenerator.generateCsr( containerIdentity, signedIdentityDocument.providerUniqueId(), signedIdentityDocument.ipAddresses(), keyPair); try (ZtsClient ztsClient = new DefaultZtsClient(ztsEndpoint, hostIdentityProvider)) { InstanceIdentity instanceIdentity = @@ -191,14 +191,15 @@ public class AthenzCredentialsMaintainer { } } + @SuppressWarnings("deprecation") private void refreshIdentity() { SignedIdentityDocument identityDocument = EntityBindingsMapper.readSignedIdentityDocumentFromFile(identityDocumentFile); KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); - Pkcs10Csr csr = csrGenerator.generateCsr(containerIdentity, identityDocument.providerUniqueId(), identityDocument.ipAddresses(), keyPair); + com.yahoo.vespa.athenz.tls.Pkcs10Csr csr = csrGenerator.generateCsr(containerIdentity, identityDocument.providerUniqueId(), identityDocument.ipAddresses(), keyPair); SSLContext containerIdentitySslContext = new SslContextBuilder() - .withKeyStore(privateKeyFile.toFile(), certificateFile.toFile()) - .withTrustStore(trustStorePath.toFile(), KeyStoreType.JKS) + .withKeyStore(privateKeyFile, certificateFile) + .withTrustStore(trustStorePath, KeyStoreType.JKS) .build(); try { try (ZtsClient ztsClient = new DefaultZtsClient(ztsEndpoint, containerIdentity, containerIdentitySslContext)) { diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdmin.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdmin.java index a0657c3d34c..16992bcb13a 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdmin.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdmin.java @@ -6,7 +6,6 @@ import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeSpec; import java.time.Duration; import java.util.List; import java.util.Map; -import java.util.Set; /** * NodeAdmin manages the life cycle of NodeAgents. diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java index 96e1461bc32..ba8a2e55587 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java @@ -157,8 +157,8 @@ public class NodeAdminImpl implements NodeAdmin { Map<String, Object> debug = new LinkedHashMap<>(); debug.put("isFrozen", isFrozen); - List<Map<String, Object>> nodeAgentDebugs = nodeAgentsByHostname.entrySet().stream() - .map(node -> node.getValue().debugInfo()).collect(Collectors.toList()); + List<Map<String, Object>> nodeAgentDebugs = nodeAgentsByHostname.values().stream() + .map(NodeAgent::debugInfo).collect(Collectors.toList()); debug.put("NodeAgents", nodeAgentDebugs); return debug; } @@ -171,7 +171,7 @@ public class NodeAdminImpl implements NodeAdmin { } catch (Throwable e) { logger.warning("Metric fetcher scheduler failed", e); } - }, 0, 55, TimeUnit.SECONDS); + }, 10, 55, TimeUnit.SECONDS); int delay = 120; // WARNING: Reducing this will increase the load on config servers. aclScheduler.scheduleWithFixedDelay(() -> { diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java index 5f2093c4719..7c84150009e 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java @@ -26,14 +26,11 @@ import com.yahoo.vespa.hosted.node.admin.maintenance.identity.AthenzCredentialsM import com.yahoo.vespa.hosted.node.admin.util.PrefixLogger; import com.yahoo.vespa.hosted.provision.Node; -import java.text.SimpleDateFormat; import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; -import java.util.Date; import java.util.LinkedHashMap; -import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Optional; @@ -80,9 +77,6 @@ public class NodeAgentImpl implements NodeAgent { private final Duration timeBetweenEachConverge; private final AthenzCredentialsMaintainer athenzCredentialsMaintainer; - private final SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); - private final LinkedList<String> debugMessages = new LinkedList<>(); - private int numberOfUnhandledException = 0; private Instant lastConverge; @@ -155,7 +149,7 @@ public class NodeAgentImpl implements NodeAgent { synchronized (monitor) { if (wantFrozen != frozen) { wantFrozen = frozen; - addDebugMessage(wantFrozen ? "Freezing" : "Unfreezing"); + logger.debug(wantFrozen ? "Freezing" : "Unfreezing"); signalWorkToBeDone(); } @@ -163,17 +157,6 @@ public class NodeAgentImpl implements NodeAgent { } } - private void addDebugMessage(String message) { - synchronized (debugMessages) { - while (debugMessages.size() > 1000) { - debugMessages.pop(); - } - - logger.debug(message); - debugMessages.add("[" + sdf.format(new Date()) + "] " + message); - } - } - @Override public Map<String, Object> debugInfo() { Map<String, Object> debug = new LinkedHashMap<>(); @@ -182,18 +165,13 @@ public class NodeAgentImpl implements NodeAgent { debug.put("wantFrozen", wantFrozen); debug.put("terminated", terminated); debug.put("workToDoNow", workToDoNow); - synchronized (debugMessages) { - debug.put("history", new LinkedList<>(debugMessages)); - } debug.put("nodeRepoState", lastNode.getState().name()); return debug; } @Override public void start() { - String message = "Starting with interval " + timeBetweenEachConverge.toMillis() + " ms"; - logger.info(message); - addDebugMessage(message); + logger.info("Starting with interval " + timeBetweenEachConverge.toMillis() + " ms"); loopThread.start(); @@ -213,7 +191,6 @@ public class NodeAgentImpl implements NodeAgent { @Override public void stop() { - addDebugMessage("Stopping"); filebeatRestarter.shutdown(); if (!terminated.compareAndSet(false, true)) { throw new RuntimeException("Can not re-stop a node agent."); @@ -240,7 +217,7 @@ public class NodeAgentImpl implements NodeAgent { currentFilebeatRestarter = Optional.of(filebeatRestarter.scheduleWithFixedDelay( () -> serviceRestarter.accept("filebeat"), 1, 1, TimeUnit.DAYS)); - addDebugMessage("Starting optional node program resume command"); + logger.debug("Starting optional node program resume command"); dockerOperations.resumeNode(containerName); resumeScriptRun = true; } @@ -266,8 +243,6 @@ public class NodeAgentImpl implements NodeAgent { if (!currentAttributes.equals(wantedAttributes)) { logger.info("Publishing new set of attributes to node repo: " + currentAttributes + " -> " + wantedAttributes); - addDebugMessage("Publishing new set of attributes to node repo: {" + - currentAttributes + "} -> {" + wantedAttributes + "}"); nodeRepository.updateNodeAttributes(hostname, wantedAttributes); } } @@ -386,7 +361,7 @@ public class NodeAgentImpl implements NodeAgent { synchronized (monitor) { if (!workToDoNow) { workToDoNow = true; - addDebugMessage("Signaling work to be done"); + logger.debug("Signaling work to be done"); monitor.notifyAll(); } } @@ -421,21 +396,19 @@ public class NodeAgentImpl implements NodeAgent { boolean converged = false; if (isFrozenCopy) { - addDebugMessage("tick: isFrozen"); + logger.debug("tick: isFrozen"); } else { try { converge(); converged = true; } catch (OrchestratorException e) { logger.info(e.getMessage()); - addDebugMessage(e.getMessage()); } catch (DockerException e) { numberOfUnhandledException++; logger.error("Caught a DockerException, resetting containerState to " + containerState, e); } catch (Exception e) { numberOfUnhandledException++; logger.error("Unhandled exception, ignoring.", e); - addDebugMessage(e.getMessage()); } } @@ -462,7 +435,7 @@ public class NodeAgentImpl implements NodeAgent { storageMaintainer.writeMetricsConfig(containerName, node); } - addDebugMessage("Loading new node spec: " + node.toString()); + logger.debug("Loading new node spec: " + node.toString()); lastNode = node; } @@ -484,7 +457,7 @@ public class NodeAgentImpl implements NodeAgent { scheduleDownLoadIfNeeded(node); if (isDownloadingImage()) { - addDebugMessage("Waiting for image to download " + imageBeingDownloaded.asString()); + logger.debug("Waiting for image to download " + imageBeingDownloaded.asString()); return; } container = removeContainerIfNeededUpdateContainerState(node, container); diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileWriter.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileWriter.java index 58518ae5a15..c41fd71c62c 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileWriter.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileWriter.java @@ -2,10 +2,10 @@ package com.yahoo.vespa.hosted.node.admin.task.util.file; import com.yahoo.vespa.hosted.node.admin.component.TaskContext; -import org.glassfish.jersey.internal.util.Producer; import java.nio.file.Files; import java.nio.file.Path; +import java.util.function.Supplier; /** * Write a file @@ -16,11 +16,11 @@ public class FileWriter { private final Path path; private final FileSync fileSync; private final PartialFileData.Builder fileDataBuilder = PartialFileData.builder(); - private final Producer<String> contentProducer; + private final Supplier<String> contentProducer; private boolean overwriteExistingFile = true; - public FileWriter(Path path, Producer<String> contentProducer) { + public FileWriter(Path path, Supplier<String> contentProducer) { this.path = path; this.fileSync = new FileSync(path); this.contentProducer = contentProducer; @@ -51,7 +51,7 @@ public class FileWriter { return false; } - fileDataBuilder.withContent(contentProducer.call()); + fileDataBuilder.withContent(contentProducer.get()); PartialFileData fileData = fileDataBuilder.create(); return fileSync.convergeTo(context, fileData); } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/systemd/SystemCtl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/systemd/SystemCtl.java index 351856c4852..b61ebb610af 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/systemd/SystemCtl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/systemd/SystemCtl.java @@ -41,6 +41,12 @@ public class SystemCtl { this.terminal = terminal; } + public void daemonReload(TaskContext taskContext) { + terminal.newCommandLine(taskContext) + .add("systemctl", "daemon-reload") + .execute(); + } + public SystemCtlEnable enable(String unit) { return new SystemCtlEnable(unit); } public SystemCtlDisable disable(String unit) { return new SystemCtlDisable(unit); } public SystemCtlStart start(String unit) { return new SystemCtlStart(unit); } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepo.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepo.java deleted file mode 100644 index 5df790f9105..00000000000 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepo.java +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.node.admin.task.util.yum; - -import com.yahoo.vespa.hosted.node.admin.component.TaskContext; -import com.yahoo.vespa.hosted.node.admin.task.util.file.FileWriter; - -import java.nio.file.FileSystem; -import java.nio.file.FileSystems; -import java.nio.file.Path; -import java.util.regex.Pattern; - -/** - * @author hakonhall - */ -public class AddYumRepo { - private static final Pattern REPOSITORY_ID_PATTERN = Pattern.compile("^[a-zA-Z0-9_-]+$"); - - private final String repositoryId; // e.g. "platform_rpms-latest" - private final String name; // e.g. "Platform RPM Latest Repo" - private final String baseurl; - private final boolean enabled; - private final FileSystem fileSystem; - - public AddYumRepo(String repositoryId, - String name, - String baseurl, - boolean enabled) { - this(repositoryId, name, baseurl, enabled, FileSystems.getDefault()); - } - - public boolean converge(TaskContext context) { - Path path = fileSystem.getPath("/etc/yum.repos.d",repositoryId + ".repo"); - - FileWriter fileWriter = new FileWriter(path, this::getRepoFileContent) - .withOwner("root") - .withGroup("root") - .withPermissions("rw-r--r--") - .onlyIfFileDoesNotAlreadyExist(); - - return fileWriter.converge(context); - } - - private String getRepoFileContent() { - return String.join("\n", - "# This file was generated by node admin", - "# Do NOT modify this file by hand", - "", - "[" + repositoryId + "]", - "name=" + name, - "baseurl=" + baseurl, - "enabled=" + (enabled ? 1 : 0), - "gpgcheck=0" - ) + "\n"; - } - - private static void validateRepositoryId(String repositoryId) { - if (!REPOSITORY_ID_PATTERN.matcher(repositoryId).matches()) { - throw new IllegalArgumentException("Invalid repository ID '" + repositoryId + "'"); - } - } - - // For testing - public AddYumRepo(String repositoryId, - String name, - String baseurl, - boolean enabled, - FileSystem fileSystem) { - this.repositoryId = repositoryId; - this.name = name; - this.baseurl = baseurl; - this.enabled = enabled; - this.fileSystem = fileSystem; - validateRepositoryId(repositoryId); - } -} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/Yum.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/Yum.java index cb23f053086..d7a503f5dcd 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/Yum.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/Yum.java @@ -3,35 +3,147 @@ package com.yahoo.vespa.hosted.node.admin.task.util.yum; import com.yahoo.vespa.hosted.node.admin.component.TaskContext; import com.yahoo.vespa.hosted.node.admin.task.util.process.CommandLine; +import com.yahoo.vespa.hosted.node.admin.task.util.process.CommandResult; import com.yahoo.vespa.hosted.node.admin.task.util.process.Terminal; +import java.util.ArrayList; import java.util.Arrays; -import java.util.HashSet; import java.util.List; import java.util.Optional; -import java.util.Set; +import java.util.function.Function; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; /** * @author hakonhall */ public class Yum { // Note: "(?dm)" makes newline be \n (only), and enables multiline mode where ^$ match lines with find() - private static final Pattern INSTALL_NOOP_PATTERN = Pattern.compile("(?dm)^Nothing to do$"); + private static final Pattern CHECKING_FOR_UPDATE_PATTERN = + Pattern.compile("(?dm)^Package matching [^ ]+ already installed\\. Checking for update\\.$"); + private static final Pattern NOTHING_TO_DO_PATTERN = Pattern.compile("(?dm)^Nothing to do$"); + private static final Pattern INSTALL_NOOP_PATTERN = NOTHING_TO_DO_PATTERN; private static final Pattern UPGRADE_NOOP_PATTERN = Pattern.compile("(?dm)^No packages marked for update$"); private static final Pattern REMOVE_NOOP_PATTERN = Pattern.compile("(?dm)^No Packages marked for removal$"); - private static final Pattern UNKNOWN_PACKAGE_PATTERN = Pattern.compile( "(?dm)^No package ([^ ]+) available\\.$"); + + // WARNING: These must be in the same order as the supplier below + private static final String RPM_QUERYFORMAT = Stream.of("NAME", "EPOCH", "VERSION", "RELEASE", "ARCH") + .map(formatter -> "%{" + formatter + "}") + .collect(Collectors.joining("\\n")); + private static final Function<YumPackageName.Builder, List<Function<String, YumPackageName.Builder>>> + PACKAGE_NAME_BUILDERS_GENERATOR = builder -> Arrays.asList( + builder::setName, builder::setEpoch, builder::setVersion, builder::setRelease, builder::setArchitecture); + + private final Terminal terminal; public Yum(Terminal terminal) { this.terminal = terminal; } + public Optional<YumPackageName> queryInstalled(TaskContext context, String packageName) { + CommandResult commandResult = terminal.newCommandLine(context) + .add("rpm", "-q", packageName, "--queryformat", RPM_QUERYFORMAT) + .ignoreExitCode() + .executeSilently(); + + if (commandResult.getExitCode() != 0) return Optional.empty(); + + YumPackageName.Builder builder = new YumPackageName.Builder(); + List<Function<String, YumPackageName.Builder>> builders = PACKAGE_NAME_BUILDERS_GENERATOR.apply(builder); + List<Optional<String>> lines = commandResult.mapEachLine(line -> Optional.of(line).filter(s -> !"(none)".equals(s))); + if (lines.size() != builders.size()) throw new IllegalStateException(String.format( + "Unexpected response from rpm, expected %d lines, got %d" + builders.size(), commandResult.getOutput())); + + IntStream.range(0, builders.size()).forEach(i -> lines.get(i).ifPresent(builders.get(i)::apply)); + return Optional.of(builder.build()); + } + + /** + * Lock and install, or if necessary downgrade, a package to a given version. + * + * @return false only if the package was already locked and installed at the given version (no-op) + */ + public boolean installFixedVersion(TaskContext context, YumPackageName yumPackage) { + String targetVersionLockName = yumPackage.toVersionLockName(); + + boolean alreadyLocked = terminal + .newCommandLine(context) + .add("yum", "--quiet", "versionlock", "list") + .executeSilently() + .getOutputLinesStream() + .map(YumPackageName::parseString) + .filter(Optional::isPresent) // removes garbage first lines, even with --quiet + .map(Optional::get) + .anyMatch(packageName -> { + // Ignore lines for other packages + if (packageName.getName().equals(yumPackage.getName())) { + // If existing lock doesn't exactly match the full package name, + // it means it's locked to another version and we must remove that lock. + String versionLockName = packageName.toVersionLockName(); + if (versionLockName.equals(targetVersionLockName)) { + return true; + } else { + terminal.newCommandLine(context) + .add("yum", "versionlock", "delete", versionLockName) + .execute(); + } + } + + return false; + }); + + boolean modified = false; + + if (!alreadyLocked) { + terminal.newCommandLine(context) + .add("yum", "versionlock", "add", targetVersionLockName) + .execute(); + modified = true; + } + + // The following 3 things may happen with yum install: + // 1. The package is installed or upgraded to the target version, in case we'd return + // true from converge() + // 2. The package is already installed at target version, in case + // "Nothing to do" is printed in the last line and we may return false from converge() + // 3. The package is already installed but at a later version than the target version, + // in case the last 2 lines of the output is: + // - "Package matching yakl-client-0.10-654.el7.x86_64 already installed. Checking for update." + // - "Nothing to do" + // And in case we need to downgrade and return true from converge() + + CommandLine commandLine = terminal + .newCommandLine(context) + .add("yum", "install", "--assumeyes", yumPackage.toName()); + + String output = commandLine.executeSilently().getUntrimmedOutput(); + + if (NOTHING_TO_DO_PATTERN.matcher(output).find()) { + if (CHECKING_FOR_UPDATE_PATTERN.matcher(output).find()) { + // case 3. + terminal.newCommandLine(context) + .add("yum", "downgrade", "--assumeyes", yumPackage.toName()) + .execute(); + modified = true; + } else { + // case 2. + } + } else { + // case 1. + commandLine.recordSilentExecutionAsSystemModification(); + modified = true; + } + + return modified; + } + public GenericYumCommand install(YumPackageName... packages) { return newYumCommand("install", packages, INSTALL_NOOP_PATTERN); } @@ -81,8 +193,7 @@ public class Yum { private final List<YumPackageName> packages; private final Pattern commandOutputNoopPattern; - private Optional<String> enabledRepo = Optional.empty(); - private boolean lockVersion = false; + private final List<String> enabledRepo = new ArrayList<>(); private GenericYumCommand(Terminal terminal, String yumCommand, @@ -99,62 +210,15 @@ public class Yum { } @SuppressWarnings("unchecked") - public GenericYumCommand enableRepo(String repo) { - enabledRepo = Optional.of(repo); - return this; - } - - /** - * Ensure the version of the installs are locked. - * - * <p>WARNING: In order to simplify the user interface of {@link #lockVersion()}, - * the package name specified in the command, e.g. {@link #install(String, String...)}, MUST be of - * a simple format, see {@link YumPackageName#fromString(String)}. - */ - public GenericYumCommand lockVersion() { - // Verify each package has sufficient info to form a proper version lock name. - packages.forEach(YumPackageName::toVersionLockName); - lockVersion = true; + public GenericYumCommand enableRepos(String... repos) { + enabledRepo.addAll(Arrays.asList(repos)); return this; } public boolean converge(TaskContext context) { - Set<String> packageNamesToLock = new HashSet<>(); - Set<String> fullPackageNamesToLock = new HashSet<>(); - - if (lockVersion) { - // Remove all locks for other version - - packages.forEach(packageName -> { - packageNamesToLock.add(packageName.getName()); - fullPackageNamesToLock.add(packageName.toVersionLockName()); - }); - - terminal.newCommandLine(context) - .add("yum", "--quiet", "versionlock", "list") - .executeSilently() - .getOutputLinesStream() - .map(YumPackageName::parseString) - .filter(Optional::isPresent) - .map(Optional::get) - .forEach(packageName -> { - // Ignore lines for other packages - if (packageNamesToLock.contains(packageName.getName())) { - // If existing lock doesn't exactly match the full package name, - // it means it's locked to another version and we must remove that lock. - String versionLockName = packageName.toVersionLockName(); - if (!fullPackageNamesToLock.remove(versionLockName)) { - terminal.newCommandLine(context) - .add("yum", "versionlock", "delete", versionLockName) - .execute(); - } - } - }); - } - CommandLine commandLine = terminal.newCommandLine(context); commandLine.add("yum", yumCommand, "--assumeyes"); - enabledRepo.ifPresent(repo -> commandLine.add("--enablerepo=" + repo)); + enabledRepo.forEach(repo -> commandLine.add("--enablerepo=" + repo)); commandLine.add(packages.stream().map(YumPackageName::toName).collect(Collectors.toList())); // There's no way to figure out whether a yum command would have been a no-op. @@ -167,12 +231,6 @@ public class Yum { commandLine.recordSilentExecutionAsSystemModification(); } - fullPackageNamesToLock.forEach(fullPackageName -> - terminal.newCommandLine(context) - .add("yum", "versionlock", "add", fullPackageName) - .execute()); - modifiedSystem |= !fullPackageNamesToLock.isEmpty(); - return modifiedSystem; } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumPackageName.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumPackageName.java index d894af9d378..a3a154e2175 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumPackageName.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumPackageName.java @@ -1,6 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.node.admin.task.util.yum; +import com.google.common.base.Strings; + import java.util.Arrays; import java.util.Objects; import java.util.Optional; @@ -43,11 +45,11 @@ public class YumPackageName { "([a-z0-9._]*[0-9][a-z0-9._]*)$"); // rel contains at least one digit private static final Pattern NAME_PATTERN = Pattern.compile("^[a-z0-9._-]+$"); - public final Optional<String> epoch; - public final String name; - public final Optional<String> version; - public final Optional<String> release; - public final Optional<String> architecture; + private final Optional<String> epoch; + private final String name; + private final Optional<String> version; + private final Optional<String> release; + private final Optional<String> architecture; public static class Builder { private Optional<String> epoch = Optional.empty(); @@ -56,6 +58,8 @@ public class YumPackageName { private Optional<String> release = Optional.empty(); private Optional<String> architecture = Optional.empty(); + public Builder() { } + public Builder(String name) { this.name = name; } @@ -70,8 +74,8 @@ public class YumPackageName { public Builder setEpoch(String epoch) { this.epoch = Optional.of(epoch); return this; } public Builder setName(String name) { this.name = name; return this; } - public Builder setRelease(String version) { this.version = Optional.of(version); return this; } - public Builder setVersion(String release) { this.release = Optional.of(release); return this; } + public Builder setVersion(String version) { this.version = Optional.of(version); return this; } + public Builder setRelease(String release) { this.release = Optional.of(release); return this; } public Builder setArchitecture(String architecture) { this.architecture = Optional.of(architecture); return this; } public YumPackageName build() { return new YumPackageName(epoch, name, version, release, architecture); } @@ -83,6 +87,9 @@ public class YumPackageName { Optional<String> version, Optional<String> release, Optional<String> architecture) { + if (Strings.isNullOrEmpty(name)) + throw new IllegalArgumentException("name cannot be null or empty"); + this.epoch = epoch; this.name = name; this.version = version; @@ -235,6 +242,14 @@ public class YumPackageName { "*"); } + public boolean isSubsetOf(YumPackageName other) { + return Objects.equals(name, other.name) && + (!epoch.isPresent() || Objects.equals(epoch, other.epoch)) && + (!version.isPresent() || Objects.equals(version, other.version)) && + (!release.isPresent() || Objects.equals(release, other.release)) && + (!architecture.isPresent() || Objects.equals(architecture, other.architecture)); + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -249,7 +264,6 @@ public class YumPackageName { @Override public int hashCode() { - return Objects.hash(epoch, name, version, release, architecture); } } diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java index 85b62660687..da2dde18d96 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java @@ -191,4 +191,12 @@ public class RealNodeRepositoryTest { assertTrue(nodeRepositoryApi.getOptionalNode("host123-1.domain.tld").isPresent()); } + + @Test + public void testRebootScheduling() { + NodeSpec nodeSpec = nodeRepositoryApi.getNode("host5.yahoo.com"); + nodeRepositoryApi.scheduleReboot(nodeSpec.getHostname()); + NodeSpec newNodeSpec = nodeRepositoryApi.getNode(nodeSpec.getHostname()); + assertEquals(nodeSpec.getWantedRebootGeneration() + 1, newNodeSpec.getWantedRebootGeneration()); + } } diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImplTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImplTest.java index e2db75eb6fb..fa94a7ff819 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImplTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImplTest.java @@ -39,6 +39,7 @@ public class DockerOperationsImplTest { .environment("prod") .system("main") .cloud("mycloud") + .dockerNetworking(DockerNetworking.HOST_NETWORK) .build(); private final Docker docker = mock(Docker.class); private final ProcessExecuter processExecuter = mock(ProcessExecuter.class); diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerMock.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerMock.java index 9b9bb2af26c..4b4ef05593d 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerMock.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerMock.java @@ -167,6 +167,11 @@ public class DockerMock implements Docker { } @Override + public CreateContainerCommand withSharedVolume(String path, String volumePath) { + return this; + } + + @Override public CreateContainerCommand withNetworkMode(String mode) { return this; } diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerTester.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerTester.java index 603ad3ebccf..15bb2825738 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerTester.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerTester.java @@ -10,11 +10,11 @@ import com.yahoo.vespa.hosted.dockerapi.Docker; import com.yahoo.vespa.hosted.dockerapi.metrics.MetricReceiverWrapper; import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeSpec; import com.yahoo.vespa.hosted.node.admin.config.ConfigServerConfig; +import com.yahoo.vespa.hosted.node.admin.docker.DockerNetworking; import com.yahoo.vespa.hosted.node.admin.docker.DockerOperations; import com.yahoo.vespa.hosted.node.admin.docker.DockerOperationsImpl; import com.yahoo.vespa.hosted.node.admin.maintenance.acl.AclMaintainer; import com.yahoo.vespa.hosted.node.admin.maintenance.identity.AthenzCredentialsMaintainer; -import com.yahoo.vespa.hosted.node.admin.nodeadmin.NodeAdmin; import com.yahoo.vespa.hosted.node.admin.nodeadmin.NodeAdminImpl; import com.yahoo.vespa.hosted.node.admin.nodeadmin.NodeAdminStateUpdaterImpl; import com.yahoo.vespa.hosted.node.admin.nodeagent.NodeAgent; @@ -70,8 +70,9 @@ public class DockerTester implements AutoCloseable { .region("us-east-1") .environment("prod") .system("main") - .pathResolver(new PathResolver(PATH_TO_VESPA_HOME, Paths.get("/tmp"), Paths.get("/tmp"))) .cloud("mycloud") + .pathResolver(new PathResolver(PATH_TO_VESPA_HOME, Paths.get("/tmp"), Paths.get("/tmp"))) + .dockerNetworking(DockerNetworking.HOST_NETWORK) .build(); NodeSpec hostSpec = new NodeSpec.Builder() diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/NodeRepoMock.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/NodeRepoMock.java index 50aaaf1e123..d25d79ab457 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/NodeRepoMock.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/NodeRepoMock.java @@ -74,6 +74,11 @@ public class NodeRepoMock implements NodeRepository { } } + @Override + public void scheduleReboot(String hostname) { + + } + public void updateNodeRepositoryNode(NodeSpec nodeSpec) { nodeRepositoryNodesByHostname.put(nodeSpec.getHostname(), nodeSpec); } diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/logging/FilebeatConfigProviderTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/logging/FilebeatConfigProviderTest.java index 77c0a30ae18..f418552553e 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/logging/FilebeatConfigProviderTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/logging/FilebeatConfigProviderTest.java @@ -6,6 +6,7 @@ import com.yahoo.config.provision.NodeType; import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeSpec; import com.yahoo.vespa.hosted.node.admin.component.Environment; import com.yahoo.vespa.hosted.node.admin.config.ConfigServerConfig; +import com.yahoo.vespa.hosted.node.admin.docker.DockerNetworking; import com.yahoo.vespa.hosted.provision.Node; import org.junit.Test; @@ -104,6 +105,7 @@ public class FilebeatConfigProviderTest { .system(system) .logstashNodes(logstashNodes) .cloud("mycloud") + .dockerNetworking(DockerNetworking.HOST_NETWORK) .build(); } diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainerTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainerTest.java index 627517b824e..d9cce7f80a0 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainerTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainerTest.java @@ -10,6 +10,7 @@ import com.yahoo.vespa.hosted.dockerapi.ContainerName; import com.yahoo.vespa.hosted.dockerapi.metrics.MetricReceiverWrapper; import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeSpec; import com.yahoo.vespa.hosted.node.admin.config.ConfigServerConfig; +import com.yahoo.vespa.hosted.node.admin.docker.DockerNetworking; import com.yahoo.vespa.hosted.node.admin.docker.DockerOperations; import com.yahoo.vespa.hosted.node.admin.component.Environment; import com.yahoo.vespa.hosted.node.admin.component.PathResolver; @@ -41,8 +42,9 @@ public class StorageMaintainerTest { .region("us-east-1") .environment("prod") .system("main") - .pathResolver(new PathResolver()) .cloud("mycloud") + .pathResolver(new PathResolver()) + .dockerNetworking(DockerNetworking.HOST_NETWORK) .coredumpFeedEndpoint("http://domain.tld/docid") .build(); private final DockerOperations docker = mock(DockerOperations.class); diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/acl/AclMaintainerTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/acl/AclMaintainerTest.java index 28e21494c01..56373dda2f8 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/acl/AclMaintainerTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/maintenance/acl/AclMaintainerTest.java @@ -9,6 +9,7 @@ import com.yahoo.vespa.hosted.dockerapi.ProcessResult; import com.yahoo.vespa.hosted.node.admin.component.Environment; import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.Acl; import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeRepository; +import com.yahoo.vespa.hosted.node.admin.docker.DockerNetworking; import com.yahoo.vespa.hosted.node.admin.docker.DockerOperations; import com.yahoo.vespa.hosted.node.admin.task.util.network.IPAddressesMock; import com.yahoo.vespa.hosted.node.admin.task.util.network.IPVersion; @@ -25,12 +26,12 @@ import java.util.stream.Collectors; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.anyVararg; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.mockito.Mockito.anyVararg; public class AclMaintainerTest { @@ -49,11 +50,13 @@ public class AclMaintainerTest { public void before() { when(dockerOperations.getAllManagedContainers()).thenReturn(containerList); when(env.getCloud()).thenReturn("AWS"); + when(env.getDockerNetworking()).thenReturn(DockerNetworking.NPT); } @Test public void no_redirect_in_yahoo() { when(env.getCloud()).thenReturn("YAHOO"); + when(env.getDockerNetworking()).thenReturn(DockerNetworking.MACVLAN); Container container = addContainer("container1", "container1.host.com", Container.State.RUNNING); Map<String, Acl> acls = makeAcl(container.hostname, "4321", "2001::1"); diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImplTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImplTest.java index f5d4dcf4e5e..ebed20326a3 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImplTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImplTest.java @@ -16,6 +16,7 @@ import com.yahoo.vespa.hosted.dockerapi.metrics.MetricReceiverWrapper; import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeSpec; import com.yahoo.vespa.hosted.node.admin.config.ConfigServerConfig; import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeAttributes; +import com.yahoo.vespa.hosted.node.admin.docker.DockerNetworking; import com.yahoo.vespa.hosted.node.admin.docker.DockerOperations; import com.yahoo.vespa.hosted.node.admin.maintenance.StorageMaintainer; import com.yahoo.vespa.hosted.node.admin.maintenance.acl.AclMaintainer; @@ -88,9 +89,10 @@ public class NodeAgentImplTest { .environment("dev") .region("us-east-1") .system("main") + .cloud("mycloud") .parentHostHostname("parent.host.name.yahoo.com") .pathResolver(pathResolver) - .cloud("mycloud") + .dockerNetworking(DockerNetworking.HOST_NETWORK) .build(); private final NodeSpec.Builder nodeBuilder = new NodeSpec.Builder() diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepoTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepoTest.java deleted file mode 100644 index c6314439003..00000000000 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepoTest.java +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.vespa.hosted.node.admin.task.util.yum; - -import com.yahoo.vespa.hosted.node.admin.component.TaskContext; -import com.yahoo.vespa.hosted.node.admin.task.util.file.UnixPath; -import com.yahoo.vespa.test.file.TestFileSystem; -import org.junit.Test; - -import java.nio.file.FileSystem; -import java.time.Instant; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.mock; - -public class AddYumRepoTest { - @Test - public void converge() { - String repositoryId = "repoid"; - String name = "name"; - String baseurl = "http://foo.com/bar"; - boolean enabled = true; - - FileSystem fileSystem = TestFileSystem.create(); - AddYumRepo addYumRepo = new AddYumRepo( - repositoryId, - name, - baseurl, - enabled, - fileSystem); - - TaskContext context = mock(TaskContext.class); - - assertTrue(addYumRepo.converge(context)); - - UnixPath unixPath = new UnixPath(fileSystem.getPath("/etc/yum.repos.d/" + repositoryId + ".repo")); - String content = unixPath.readUtf8File(); - assertEquals("# This file was generated by node admin\n" + - "# Do NOT modify this file by hand\n" + - "\n" + - "[repoid]\n" + - "name=name\n" + - "baseurl=http://foo.com/bar\n" + - "enabled=1\n" + - "gpgcheck=0\n", content); - Instant lastModifiedTime = unixPath.getLastModifiedTime(); - - // Second time is a no-op - assertFalse(addYumRepo.converge(context)); - assertEquals(lastModifiedTime, unixPath.getLastModifiedTime()); - } - -}
\ No newline at end of file diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumPackageNameTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumPackageNameTest.java index 2e1ef4c0a61..01664f5c22b 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumPackageNameTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumPackageNameTest.java @@ -9,9 +9,20 @@ import static org.hamcrest.CoreMatchers.containsStringIgnoringCase; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; public class YumPackageNameTest { + @Test + public void testBuilder() { + YumPackageName yumPackage = new YumPackageName.Builder("docker") + .setEpoch("2") + .setVersion("1.12.6") + .setRelease("71.git3e8e77d.el7.centos.1") + .setArchitecture("x86_64") + .build(); + assertEquals("2:docker-1.12.6-71.git3e8e77d.el7.centos.1.x86_64", yumPackage.toName()); + } @Test public void testAllValidFormats() { @@ -139,4 +150,22 @@ public class YumPackageNameTest { assertThat(e.getMessage(), containsStringIgnoringCase("epoch")); } } + + @Test + public void testSubset() { + YumPackageName yumPackage = new YumPackageName.Builder("docker") + .setVersion("1.12.6") + .build(); + + assertTrue(yumPackage.isSubsetOf(yumPackage)); + assertTrue(yumPackage.isSubsetOf(new YumPackageName.Builder("docker") + .setVersion("1.12.6") + .setEpoch("2") + .setRelease("71.git3e8e77d.el7.centos.1") + .setArchitecture("x86_64") + .build())); + assertFalse(yumPackage.isSubsetOf(new YumPackageName.Builder("docker") + .setVersion("1.13.1") + .build())); + } }
\ No newline at end of file diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumTest.java index a635dd6a44d..c7e2885a907 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumTest.java @@ -4,15 +4,15 @@ package com.yahoo.vespa.hosted.node.admin.task.util.yum; import com.yahoo.vespa.hosted.node.admin.component.TaskContext; import com.yahoo.vespa.hosted.node.admin.task.util.process.ChildProcessFailureException; import com.yahoo.vespa.hosted.node.admin.task.util.process.TestTerminal; -import org.hamcrest.CoreMatchers; import org.junit.After; import org.junit.Test; +import java.util.Optional; + import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; @@ -28,6 +28,52 @@ public class YumTest { } @Test + public void testQueryInstalledNevra() { + terminal.expectCommand( + "rpm -q docker --queryformat \"%{NAME}\\\\n%{EPOCH}\\\\n%{VERSION}\\\\n%{RELEASE}\\\\n%{ARCH}\" 2>&1", + 0, + "docker\n2\n1.13.1\n74.git6e3bb8e.el7.centos\nx86_64"); + + Optional<YumPackageName> installed = yum.queryInstalled(taskContext, "docker"); + + assertTrue(installed.isPresent()); + assertEquals("docker", installed.get().getName()); + assertEquals("2", installed.get().getEpoch().get()); + assertEquals("1.13.1", installed.get().getVersion().get()); + assertEquals("74.git6e3bb8e.el7.centos", installed.get().getRelease().get()); + assertEquals("x86_64", installed.get().getArchitecture().get()); + } + + @Test + public void testQueryInstalledPartial() { + terminal.expectCommand( + "rpm -q vespa-node-admin --queryformat \"%{NAME}\\\\n%{EPOCH}\\\\n%{VERSION}\\\\n%{RELEASE}\\\\n%{ARCH}\" 2>&1", + 0, + "vespa-node-admin\n(none)\n6.283.62\n1.el7\nnoarch"); + + Optional<YumPackageName> installed = yum.queryInstalled(taskContext, "vespa-node-admin"); + + assertTrue(installed.isPresent()); + assertEquals("vespa-node-admin", installed.get().getName()); + assertFalse(installed.get().getEpoch().isPresent()); + assertEquals("6.283.62", installed.get().getVersion().get()); + assertEquals("1.el7", installed.get().getRelease().get()); + assertEquals("noarch", installed.get().getArchitecture().get()); + } + + @Test + public void testQueryNotInstalled() { + terminal.expectCommand( + "rpm -q fake-package --queryformat \"%{NAME}\\\\n%{EPOCH}\\\\n%{VERSION}\\\\n%{RELEASE}\\\\n%{ARCH}\" 2>&1", + 1, + "package fake-package is not installed"); + + Optional<YumPackageName> installed = yum.queryInstalled(taskContext, "fake-package"); + + assertFalse(installed.isPresent()); + } + + @Test public void testArrayConversion() { YumPackageName[] expected = new YumPackageName[] { new YumPackageName.Builder("1").build() }; assertArrayEquals(expected, Yum.toYumPackageNameArray("1")); @@ -49,13 +95,13 @@ public class YumTest { @Test public void testAlreadyInstalled() { terminal.expectCommand( - "yum install --assumeyes --enablerepo=repo-name package-1 package-2 2>&1", + "yum install --assumeyes --enablerepo=repo1 --enablerepo=repo2 package-1 package-2 2>&1", 0, "foobar\nNothing to do\n"); assertFalse(yum .install("package-1", "package-2") - .enableRepo("repo-name") + .enableRepos("repo1", "repo2") .converge(taskContext)); } @@ -104,7 +150,7 @@ public class YumTest { assertTrue(yum .install("package-1", "package-2") - .enableRepo("repo-name") + .enableRepos("repo-name") .converge(taskContext)); } @@ -114,16 +160,14 @@ public class YumTest { 0, "Repository chef_rpms-release is listed more than once in the configuration\n" + "0:chef-12.21.1-1.el7.*\n"); + terminal.expectCommand("yum versionlock add \"0:package-1-0.10-654.el7.*\" 2>&1"); terminal.expectCommand( - "yum install --assumeyes \"0:package-1-0.10-654.el7.*\" 2>&1", + "yum install --assumeyes 0:package-1-0.10-654.el7.x86_64 2>&1", 0, "installing"); - terminal.expectCommand("yum versionlock add \"0:package-1-0.10-654.el7.*\" 2>&1"); - assertTrue(yum - .install("0:package-1-0.10-654.el7.*") - .lockVersion() - .converge(taskContext)); + assertTrue(yum.installFixedVersion(taskContext, + YumPackageName.fromString("0:package-1-0.10-654.el7.x86_64"))); } @Test @@ -136,17 +180,15 @@ public class YumTest { terminal.expectCommand("yum versionlock delete \"0:package-1-0.1-8.el7.*\" 2>&1"); + terminal.expectCommand("yum versionlock add \"0:package-1-0.10-654.el7.*\" 2>&1"); + terminal.expectCommand( - "yum install --assumeyes \"0:package-1-0.10-654.el7.*\" 2>&1", + "yum install --assumeyes 0:package-1-0.10-654.el7 2>&1", 0, "Nothing to do\n"); - terminal.expectCommand("yum versionlock add \"0:package-1-0.10-654.el7.*\" 2>&1"); - assertTrue(yum - .install("0:package-1-0.10-654.el7.*") - .lockVersion() - .converge(taskContext)); + assertTrue(yum.installFixedVersion(taskContext, YumPackageName.fromString("0:package-1-0.10-654.el7"))); } @Test @@ -157,24 +199,30 @@ public class YumTest { "0:chef-12.21.1-1.el7.*\n" + "0:package-1-0.10-654.el7.*\n"); terminal.expectCommand( - "yum install --assumeyes \"0:package-1-0.10-654.el7.*\" 2>&1", + "yum install --assumeyes 0:package-1-0.10-654.el7 2>&1", 0, "Nothing to do\n"); - assertFalse(yum - .install("0:package-1-0.10-654.el7.*") - .lockVersion() - .converge(taskContext)); + assertFalse(yum.installFixedVersion(taskContext, YumPackageName.fromString("0:package-1-0.10-654.el7"))); } @Test - public void testBadPackageNameWithLock() { - try { - yum.install("package-1-0.10-654.el7").lockVersion(); - fail(); - } catch (IllegalStateException e) { - assertThat(e.getMessage(), CoreMatchers.containsStringIgnoringCase("epoch is missing")); - } + public void testWithDowngrade() { + terminal.expectCommand("yum --quiet versionlock list 2>&1", + 0, + "Repository chef_rpms-release is listed more than once in the configuration\n" + + "0:chef-12.21.1-1.el7.*\n" + + "0:package-1-0.10-654.el7.*\n"); + + terminal.expectCommand( + "yum install --assumeyes 0:package-1-0.10-654.el7 2>&1", + 0, + "Package matching package-1-0.10-654.el7 already installed. Checking for update.\n" + + "Nothing to do\n"); + + terminal.expectCommand("yum downgrade --assumeyes 0:package-1-0.10-654.el7 2>&1"); + + assertTrue(yum.installFixedVersion(taskContext, YumPackageName.fromString("0:package-1-0.10-654.el7"))); } @Test(expected = ChildProcessFailureException.class) @@ -185,7 +233,7 @@ public class YumTest { "error"); yum.install("package-1", "package-2") - .enableRepo("repo-name") + .enableRepos("repo-name") .converge(taskContext); fail(); } diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/util/EnvironmentTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/util/EnvironmentTest.java index a3a455605ad..893607f1806 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/util/EnvironmentTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/util/EnvironmentTest.java @@ -6,6 +6,7 @@ import com.yahoo.vespa.hosted.dockerapi.ContainerName; import com.yahoo.vespa.hosted.node.admin.component.Environment; import com.yahoo.vespa.hosted.node.admin.component.PathResolver; import com.yahoo.vespa.hosted.node.admin.config.ConfigServerConfig; +import com.yahoo.vespa.hosted.node.admin.docker.DockerNetworking; import org.junit.Test; import java.nio.file.Path; @@ -22,8 +23,9 @@ public class EnvironmentTest { .region("us-east-1") .environment("prod") .system("main") - .pathResolver(new PathResolver()) .cloud("mycloud") + .pathResolver(new PathResolver()) + .dockerNetworking(DockerNetworking.HOST_NETWORK) .build(); @Test diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java index 69b31f506e5..e03250e7934 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java @@ -183,8 +183,10 @@ public class NodeRepository extends AbstractComponent { // For all cases below, trust: // - nodes in same application // - config servers + // - ssh node.allocation().ifPresent(allocation -> trustedNodes.addAll(candidates.owner(allocation.owner()).asList())); trustedNodes.addAll(candidates.nodeType(NodeType.config).asList()); + trustedPorts.add(22); switch (node.type()) { case tenant: diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/CapacityPolicies.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/CapacityPolicies.java index efa1cd3745d..6168d6fcf78 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/CapacityPolicies.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/CapacityPolicies.java @@ -48,6 +48,8 @@ public class CapacityPolicies { return flavors.getFlavorOrThrow(requestedFlavor.get()); String defaultFlavorName = zone.defaultFlavor(cluster.type()); + if (zone.system() == SystemName.cd) + return flavors.getFlavorOrThrow(requestedFlavor.orElse(defaultFlavorName)); switch(zone.environment()) { case dev : case test : case staging : return flavors.getFlavorOrThrow(defaultFlavorName); default : return flavors.getFlavorOrThrow(requestedFlavor.orElse(defaultFlavorName)); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java index 0ef5c03e543..06a86cbddf7 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java @@ -81,7 +81,8 @@ public class NodeRepositoryProvisioner implements Provisioner { int effectiveGroups; NodeSpec requestedNodes; if ( requestedCapacity.type() == NodeType.tenant) { - int nodeCount = capacityPolicies.decideSize(requestedCapacity); + int nodeCount = application.instance().isTester() ? 1 : capacityPolicies.decideSize(requestedCapacity); + if (zone.environment().isManuallyDeployed() && nodeCount < requestedCapacity.nodeCount()) logger.log(Level.INFO, "Requested " + requestedCapacity.nodeCount() + " nodes for " + cluster + ", downscaling to " + nodeCount + " nodes in " + zone.environment()); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifier.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifier.java index 90c24f6bb23..0891279f30c 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifier.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifier.java @@ -6,8 +6,8 @@ import com.google.common.base.Suppliers; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.Zone; import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId; -import com.yahoo.vespa.athenz.tls.SubjectAlternativeName; -import com.yahoo.vespa.athenz.tls.X509CertificateUtils; +import com.yahoo.security.SubjectAlternativeName; +import com.yahoo.security.X509CertificateUtils; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeRepository; @@ -16,7 +16,7 @@ import java.util.List; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; -import static com.yahoo.vespa.athenz.tls.SubjectAlternativeName.Type.DNS_NAME; +import static com.yahoo.security.SubjectAlternativeName.Type.DNS_NAME; /** * Resolve node from various types of x509 identity certificates. diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/NodeFailerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/NodeFailerTest.java index 71b0b125e0f..08cf8e7dc20 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/NodeFailerTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/NodeFailerTest.java @@ -252,7 +252,9 @@ public class NodeFailerTest { @Test public void docker_host_failed_without_config_requests() { - NodeFailTester tester = NodeFailTester.withTwoApplications(); + NodeFailTester tester = NodeFailTester.withTwoApplications( + new ConfigserverConfig(new ConfigserverConfig.Builder().nodeAdminInContainer(true)) + ); // For a day all nodes work so nothing happens for (int minutes = 0, interval = 30; minutes < 24 * 60; minutes += interval) { diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/FilterTester.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/FilterTester.java index 6420a5237e8..caecce1634d 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/FilterTester.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/FilterTester.java @@ -5,9 +5,12 @@ import com.yahoo.application.container.handler.Request.Method; import com.yahoo.container.jdisc.RequestHandlerTestDriver; import com.yahoo.jdisc.http.filter.DiscFilterRequest; import com.yahoo.jdisc.http.filter.SecurityRequestFilter; -import com.yahoo.vespa.athenz.tls.X509CertificateBuilder; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.X509CertificateBuilder; import javax.security.auth.x500.X500Principal; +import java.math.BigInteger; import java.net.URI; import java.security.KeyPair; import java.security.KeyPairGenerator; @@ -20,7 +23,8 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import static com.yahoo.vespa.athenz.tls.SignatureAlgorithm.SHA256_WITH_RSA; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_ECDSA; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_RSA; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -65,7 +69,7 @@ public class FilterTester { when(r.getRemoteAddr()).thenReturn(request.remoteAddr()); when(r.getLocalAddr()).thenReturn(request.localAddr()); if (request.commonName().isPresent()) { - X509Certificate cert = certificateFor(request.commonName().get(), keyPair()); + X509Certificate cert = certificateFor(request.commonName().get(), KeyUtils.generateKeypair(KeyAlgorithm.EC)); List<X509Certificate> certs = Collections.singletonList(cert); when(r.getClientCertificateChain()).thenReturn(certs); when(r.getUserPrincipal()).thenReturn(NodePrincipal.withLegacyIdentity(request.commonName().get(), certs)); @@ -73,23 +77,13 @@ public class FilterTester { return r; } - /** Create a RSA public/private key pair */ - private static KeyPair keyPair() { - try { - KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA"); - keyGen.initialize(2048); - return keyGen.generateKeyPair(); - } catch (NoSuchAlgorithmException e) { - throw new RuntimeException(e); - } - } /** Create a self signed certificate for commonName using given public/private key pair */ private static X509Certificate certificateFor(String commonName, KeyPair keyPair) { Instant now = Instant.now(); X500Principal subject = new X500Principal("CN=" + commonName); return X509CertificateBuilder - .fromKeypair(keyPair, subject, now, now.plus(Duration.ofDays(30)), SHA256_WITH_RSA, now.toEpochMilli()) + .fromKeypair(keyPair, subject, now, now.plus(Duration.ofDays(30)), SHA256_WITH_ECDSA, BigInteger.valueOf(now.toEpochMilli())) .setBasicConstraints(true, true) .build(); } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java index d02a666eb69..f7d4a9603e7 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java @@ -12,10 +12,10 @@ import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.Zone; import com.yahoo.config.provisioning.FlavorsConfig; import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId; -import com.yahoo.vespa.athenz.tls.KeyUtils; -import com.yahoo.vespa.athenz.tls.Pkcs10Csr; -import com.yahoo.vespa.athenz.tls.Pkcs10CsrBuilder; -import com.yahoo.vespa.athenz.tls.X509CertificateBuilder; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.Pkcs10Csr; +import com.yahoo.security.Pkcs10CsrBuilder; +import com.yahoo.security.X509CertificateBuilder; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeRepositoryTester; import com.yahoo.vespa.hosted.provision.node.Allocation; @@ -26,14 +26,17 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import javax.security.auth.x500.X500Principal; +import java.math.BigInteger; import java.security.KeyPair; import java.security.cert.X509Certificate; import java.time.Instant; import java.util.Optional; +import static com.yahoo.security.KeyAlgorithm.EC; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_ECDSA; import static com.yahoo.vespa.athenz.identityprovider.api.IdentityType.*; -import static com.yahoo.vespa.athenz.tls.KeyAlgorithm.RSA; -import static com.yahoo.vespa.athenz.tls.SignatureAlgorithm.SHA256_WITH_RSA; +import static com.yahoo.security.KeyAlgorithm.RSA; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_RSA; import static com.yahoo.vespa.hosted.provision.restapi.v2.filter.NodeIdentifier.CONFIGSERVER_HOST_IDENTITY; import static com.yahoo.vespa.hosted.provision.restapi.v2.filter.NodeIdentifier.PROXY_HOST_IDENTITY; import static com.yahoo.vespa.hosted.provision.restapi.v2.filter.NodeIdentifier.TENANT_DOCKER_CONTAINER_IDENTITY; @@ -64,7 +67,7 @@ public class NodeIdentifierTest { private static final String INSTANCE_ID = "default"; private static final Zone ZONE = new Zone(SystemName.main, Environment.prod, RegionName.defaultName()); - private static final KeyPair KEYPAIR = KeyUtils.generateKeypair(RSA); + private static final KeyPair KEYPAIR = KeyUtils.generateKeypair(EC); private static final X509Certificate ATHENZ_YAHOO_CA_CERT = createDummyCaCertificate("Yahoo Athenz CA"); private static final X509Certificate ATHENZ_AWS_CA_CERT = createDummyCaCertificate("Athenz AWS CA"); @@ -73,7 +76,7 @@ public class NodeIdentifierTest { NodeRepositoryTester nodeRepositoryDummy = new NodeRepositoryTester(); X509Certificate certificate = X509CertificateBuilder .fromKeypair( - KEYPAIR, new X500Principal("CN=" + HOSTNAME), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), SHA256_WITH_RSA, 1) + KEYPAIR, new X500Principal("CN=" + HOSTNAME), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), SHA256_WITH_ECDSA, BigInteger.ONE) .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); expectedException.expect(NodeIdentifier.NodeIdentifierException.class); @@ -87,10 +90,10 @@ public class NodeIdentifierTest { nodeRepositoryDummy.addNode(OPENSTACK_ID, HOSTNAME, INSTANCE_ID, NodeType.host); nodeRepositoryDummy.setNodeState(HOSTNAME, Node.State.active); Pkcs10Csr csr = Pkcs10CsrBuilder - .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_HOST_IDENTITY), KEYPAIR, SHA256_WITH_RSA) + .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_HOST_IDENTITY), KEYPAIR, SHA256_WITH_ECDSA) .build(); X509Certificate certificate = X509CertificateBuilder - .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1) + .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_ECDSA, BigInteger.ONE) .addSubjectAlternativeName(OPENSTACK_ID + ".instanceid.athenz.provider-name.ostk.yahoo.cloud") .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); @@ -106,10 +109,10 @@ public class NodeIdentifierTest { nodeRepositoryDummy.addNode(AWS_INSTANCE_ID, HOSTNAME, INSTANCE_ID, NodeType.host); nodeRepositoryDummy.setNodeState(HOSTNAME, Node.State.active); Pkcs10Csr csr = Pkcs10CsrBuilder - .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_HOST_IDENTITY), KEYPAIR, SHA256_WITH_RSA) + .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_HOST_IDENTITY), KEYPAIR, SHA256_WITH_ECDSA) .build(); X509Certificate certificate = X509CertificateBuilder - .fromCsr(csr, ATHENZ_AWS_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1) + .fromCsr(csr, ATHENZ_AWS_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_ECDSA, BigInteger.ONE) .addSubjectAlternativeName(AWS_INSTANCE_ID + ".instanceid.athenz.aws.oath.cloud") .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); @@ -125,10 +128,10 @@ public class NodeIdentifierTest { nodeRepositoryDummy.addNode(AWS_INSTANCE_ID, PROXY_HOSTNAME, INSTANCE_ID, NodeType.proxyhost); nodeRepositoryDummy.setNodeState(PROXY_HOSTNAME, Node.State.active); Pkcs10Csr csr = Pkcs10CsrBuilder - .fromKeypair(new X500Principal("CN=" + PROXY_HOST_IDENTITY), KEYPAIR, SHA256_WITH_RSA) + .fromKeypair(new X500Principal("CN=" + PROXY_HOST_IDENTITY), KEYPAIR, SHA256_WITH_ECDSA) .build(); X509Certificate certificate = X509CertificateBuilder - .fromCsr(csr, ATHENZ_AWS_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1) + .fromCsr(csr, ATHENZ_AWS_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_ECDSA, BigInteger.ONE) .addSubjectAlternativeName(AWS_INSTANCE_ID + ".instanceid.athenz.aws.oath.cloud") .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); @@ -142,10 +145,10 @@ public class NodeIdentifierTest { public void accepts_aws_configserver_host_certificate() { NodeRepositoryTester nodeRepositoryDummy = new NodeRepositoryTester(); Pkcs10Csr csr = Pkcs10CsrBuilder - .fromKeypair(new X500Principal("CN=" + CONFIGSERVER_HOST_IDENTITY), KEYPAIR, SHA256_WITH_RSA) + .fromKeypair(new X500Principal("CN=" + CONFIGSERVER_HOST_IDENTITY), KEYPAIR, SHA256_WITH_ECDSA) .build(); X509Certificate certificate = X509CertificateBuilder - .fromCsr(csr, ATHENZ_AWS_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1) + .fromCsr(csr, ATHENZ_AWS_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_ECDSA, BigInteger.ONE) .addSubjectAlternativeName(AWS_INSTANCE_ID + ".instanceid.athenz.aws.oath.cloud") .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); @@ -156,7 +159,7 @@ public class NodeIdentifierTest { @Test public void accepts_zts_certificate() { X509Certificate certificate = X509CertificateBuilder - .fromKeypair(KEYPAIR, new X500Principal("CN=" + ZTS_AWS_IDENTITY), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), SHA256_WITH_RSA, 1) + .fromKeypair(KEYPAIR, new X500Principal("CN=" + ZTS_AWS_IDENTITY), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), SHA256_WITH_ECDSA, BigInteger.ONE) .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, new NodeRepositoryTester().nodeRepository()); NodePrincipal identity = identifier.resolveNode(singletonList(certificate)); @@ -176,11 +179,11 @@ public class NodeIdentifierTest { Node node = createNode(clusterId, clusterIndex, tenant, application); nodeRepositoryDummy.nodeRepository().addDockerNodes(singletonList(node)); Pkcs10Csr csr = Pkcs10CsrBuilder - .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_CONTAINER_IDENTITY), KEYPAIR, SHA256_WITH_RSA) + .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_CONTAINER_IDENTITY), KEYPAIR, SHA256_WITH_ECDSA) .build(); VespaUniqueInstanceId vespaUniqueInstanceId = new VespaUniqueInstanceId(clusterIndex, clusterId, INSTANCE_ID, application, tenant, region, environment, NODE); X509Certificate certificate = X509CertificateBuilder - .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1) + .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_ECDSA, BigInteger.ONE) .addSubjectAlternativeName(vespaUniqueInstanceId.asDottedString() + ".instanceid.athenz.provider-name.vespa.yahoo.cloud") .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); @@ -194,10 +197,10 @@ public class NodeIdentifierTest { public void accepts_controller_certificate() { NodeRepositoryTester nodeRepositoryDummy = new NodeRepositoryTester(); Pkcs10Csr csr = Pkcs10CsrBuilder - .fromKeypair(new X500Principal("CN=" + CONTROLLER_IDENTITY), KEYPAIR, SHA256_WITH_RSA) + .fromKeypair(new X500Principal("CN=" + CONTROLLER_IDENTITY), KEYPAIR, SHA256_WITH_ECDSA) .build(); X509Certificate certificate = X509CertificateBuilder - .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1) + .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_ECDSA, BigInteger.ONE) .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); NodePrincipal identity = identifier.resolveNode(singletonList(certificate)); @@ -211,10 +214,10 @@ public class NodeIdentifierTest { nodeRepositoryDummy.addNode(OPENSTACK_ID, HOSTNAME, INSTANCE_ID, NodeType.tenant); nodeRepositoryDummy.setNodeState(HOSTNAME, Node.State.active); Pkcs10Csr csr = Pkcs10CsrBuilder - .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_CONTAINER_IDENTITY), KEYPAIR, SHA256_WITH_RSA) + .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_CONTAINER_IDENTITY), KEYPAIR, SHA256_WITH_ECDSA) .build(); X509Certificate certificate = X509CertificateBuilder - .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1) + .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_ECDSA, BigInteger.ONE) .addSubjectAlternativeName(OPENSTACK_ID + ".instanceid.athenz.ostk.yahoo.cloud") .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); @@ -251,10 +254,10 @@ public class NodeIdentifierTest { } private static X509Certificate createDummyCaCertificate(String caCommonName) { - KeyPair keyPair = KeyUtils.generateKeypair(RSA); + KeyPair keyPair = KeyUtils.generateKeypair(EC); return X509CertificateBuilder .fromKeypair( - keyPair, new X500Principal("CN=" + caCommonName), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), SHA256_WITH_RSA, 1) + keyPair, new X500Principal("CN=" + caCommonName), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), SHA256_WITH_ECDSA, BigInteger.ONE) .setBasicConstraints(true, true) .build(); diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/responses/acl-config-server.json b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/responses/acl-config-server.json index d1dc7e22fcd..e599d5a7b0b 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/responses/acl-config-server.json +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/responses/acl-config-server.json @@ -190,6 +190,10 @@ "trustedNetworks": [], "trustedPorts": [ { + "port": 22, + "trustedBy": "cfg1.yahoo.com" + }, + { "port": 4443, "trustedBy": "cfg1.yahoo.com" } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/responses/acl-docker-host.json b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/responses/acl-docker-host.json index 7b1af066065..4d6607bd1b0 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/responses/acl-docker-host.json +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/responses/acl-docker-host.json @@ -79,5 +79,10 @@ "trustedBy": "dockerhost1.yahoo.com" } ], - "trustedPorts": [] + "trustedPorts": [ + { + "port": 22, + "trustedBy": "dockerhost1.yahoo.com" + } + ] } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/responses/acl-tenant-node.json b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/responses/acl-tenant-node.json index e3ea3b62bec..4e1ba2271d9 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/responses/acl-tenant-node.json +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/responses/acl-tenant-node.json @@ -134,5 +134,10 @@ } ], "trustedNetworks": [], - "trustedPorts":[] + "trustedPorts": [ + { + "port": 22, + "trustedBy": "foo.yahoo.com" + } + ] } diff --git a/parent/pom.xml b/parent/pom.xml index 38e19ea9807..e6ad2b18df9 100644 --- a/parent/pom.xml +++ b/parent/pom.xml @@ -689,7 +689,7 @@ find zkfacade/src/main/java/org/apache/curator -name package-info.java | \ xargs perl -pi -e 's/major = [0-9]+, minor = [0-9]+, micro = [0-9]+/major = 2, minor = 9, micro = 1/g' --> - <curator.version>2.12.0</curator.version> + <curator.version>2.9.1</curator.version> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding> <test.hide>true</test.hide> @@ -1,5 +1,5 @@ <?xml version="1.0" encoding="UTF-8"?> -<!-- Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> <modelVersion>4.0.0</modelVersion> <groupId>com.yahoo.vespa</groupId> diff --git a/searchcommon/src/vespa/searchcommon/attribute/attributecontent.h b/searchcommon/src/vespa/searchcommon/attribute/attributecontent.h index 508a2d04c27..72ce1754d71 100644 --- a/searchcommon/src/vespa/searchcommon/attribute/attributecontent.h +++ b/searchcommon/src/vespa/searchcommon/attribute/attributecontent.h @@ -145,7 +145,7 @@ public: search::attribute::IAttributeVector::DocId docId) { uint32_t count = attribute.get(docId, data(), capacity()); - if (count > capacity()) { + while (count > capacity()) { allocate(count); count = attribute.get(docId, data(), capacity()); } diff --git a/searchcore/pom.xml b/searchcore/pom.xml index 3b43bf1205e..448209b1fd6 100644 --- a/searchcore/pom.xml +++ b/searchcore/pom.xml @@ -1,3 +1,4 @@ +<?xml version="1.0"?> <!-- Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" @@ -11,14 +12,15 @@ <relativePath>../parent/pom.xml</relativePath> </parent> <artifactId>searchcore</artifactId> - <version>6-SNAPSHOT</version> <packaging>jar</packaging> + <version>6-SNAPSHOT</version> <name>${project.artifactId}</name> <dependencies> <dependency> <groupId>com.yahoo.vespa</groupId> <artifactId>config-lib</artifactId> <version>${project.version}</version> + <scope>provided</scope> </dependency> </dependencies> <build> diff --git a/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp b/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp index 8ffe4807427..7af7951abba 100644 --- a/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp +++ b/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp @@ -192,8 +192,6 @@ public: virtual void replay(const SplitBucketOperation &op) override { print(op); } virtual void replay(const JoinBucketsOperation &op) override { print(op); } virtual void replay(const PruneRemovedDocumentsOperation &op) override { print(op); } - virtual void replay(const SpoolerReplayStartOperation &op) override { print(op); } - virtual void replay(const SpoolerReplayCompleteOperation &op) override { print(op); } virtual void replay(const MoveOperation &op) override { print(op); } virtual void replay(const CreateBucketOperation &op) override { print(op); } virtual void replay(const CompactLidSpaceOperation &op) override { print(op); } @@ -275,8 +273,6 @@ public: virtual void replay(const SplitBucketOperation &) override { } virtual void replay(const JoinBucketsOperation &) override { } virtual void replay(const PruneRemovedDocumentsOperation &) override { } - virtual void replay(const SpoolerReplayStartOperation &) override { } - virtual void replay(const SpoolerReplayCompleteOperation &) override { } virtual void replay(const MoveOperation &) override { } virtual void replay(const CreateBucketOperation &) override { } }; diff --git a/searchcore/src/main/java/com/yahoo/vespa/config/search/core/package-info.java b/searchcore/src/main/java/com/yahoo/vespa/config/search/core/package-info.java new file mode 100644 index 00000000000..c29162d65ae --- /dev/null +++ b/searchcore/src/main/java/com/yahoo/vespa/config/search/core/package-info.java @@ -0,0 +1,7 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +@ExportPackage +package com.yahoo.vespa.config.search.core; + +import com.yahoo.osgi.annotation.ExportPackage; + diff --git a/searchcore/src/tests/proton/feedoperation/feedoperation_test.cpp b/searchcore/src/tests/proton/feedoperation/feedoperation_test.cpp index e87e9209a17..6a9dc42b56d 100644 --- a/searchcore/src/tests/proton/feedoperation/feedoperation_test.cpp +++ b/searchcore/src/tests/proton/feedoperation/feedoperation_test.cpp @@ -12,7 +12,6 @@ #include <vespa/searchcore/proton/feedoperation/putoperation.h> #include <vespa/searchcore/proton/feedoperation/removeoperation.h> #include <vespa/searchcore/proton/feedoperation/splitbucketoperation.h> -#include <vespa/searchcore/proton/feedoperation/spoolerreplayoperation.h> #include <vespa/searchcore/proton/feedoperation/updateoperation.h> #include <vespa/searchcore/proton/feedoperation/wipehistoryoperation.h> #include <vespa/searchlib/query/base.h> @@ -212,17 +211,6 @@ TEST("require that toString() on derived classes are meaningful") "target2=BucketId(0x000000000000002c), serialNum=0)", SplitBucketOperation(bucket_id1, bucket_id2, bucket_id3) .toString()); - - EXPECT_EQUAL("SpoolerReplayStart(spoolerSerialNum=0, serialNum=0)", - SpoolerReplayStartOperation().toString()); - EXPECT_EQUAL("SpoolerReplayStart(spoolerSerialNum=20, serialNum=10)", - SpoolerReplayStartOperation(10, 20).toString()); - - EXPECT_EQUAL("SpoolerReplayComplete(spoolerSerialNum=0, serialNum=0)", - SpoolerReplayCompleteOperation().toString()); - EXPECT_EQUAL("SpoolerReplayComplete(spoolerSerialNum=2, serialNum=1)", - SpoolerReplayCompleteOperation(1, 2).toString()); - EXPECT_EQUAL("Update(NULL, BucketId(0x0000000000000000), timestamp=0, dbdId=(subDbId=0, lid=0), " "prevDbdId=(subDbId=0, lid=0), prevMarkedAsRemoved=false, prevTimestamp=0, serialNum=0)", UpdateOperation().toString()); diff --git a/searchcore/src/tests/proton/flushengine/CMakeLists.txt b/searchcore/src/tests/proton/flushengine/CMakeLists.txt index 826c9b2390f..6e8df3c9b7f 100644 --- a/searchcore/src/tests/proton/flushengine/CMakeLists.txt +++ b/searchcore/src/tests/proton/flushengine/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. vespa_add_executable(searchcore_flushengine_test_app TEST SOURCES - flushengine.cpp + flushengine_test.cpp DEPENDS searchcore_flushengine searchcore_pcommon diff --git a/searchcore/src/tests/proton/flushengine/flushengine.cpp b/searchcore/src/tests/proton/flushengine/flushengine_test.cpp index d1a98f1b7d3..f668072b9fd 100644 --- a/searchcore/src/tests/proton/flushengine/flushengine.cpp +++ b/searchcore/src/tests/proton/flushengine/flushengine_test.cpp @@ -3,15 +3,15 @@ #include <vespa/searchcore/proton/flushengine/cachedflushtarget.h> #include <vespa/searchcore/proton/flushengine/flush_engine_explorer.h> #include <vespa/searchcore/proton/flushengine/flushengine.h> +#include <vespa/searchcore/proton/flushengine/i_tls_stats_factory.h> #include <vespa/searchcore/proton/flushengine/threadedflushtarget.h> #include <vespa/searchcore/proton/flushengine/tls_stats_map.h> -#include <vespa/searchcore/proton/flushengine/i_tls_stats_factory.h> #include <vespa/searchcore/proton/server/igetserialnum.h> #include <vespa/searchcore/proton/test/dummy_flush_handler.h> #include <vespa/searchcore/proton/test/dummy_flush_target.h> -#include <vespa/vespalib/testkit/testapp.h> #include <vespa/vespalib/data/slime/slime.h> #include <vespa/vespalib/test/insertion_operators.h> +#include <vespa/vespalib/testkit/testapp.h> #include <mutex> #include <chrono> @@ -42,7 +42,6 @@ public: SimpleExecutor() : _done() { - // empty } Task::UP @@ -113,7 +112,7 @@ public: } }; -typedef std::vector<IFlushTarget::SP> Targets; +using Targets = std::vector<IFlushTarget::SP>; using FlushDoneHistory = std::vector<search::SerialNum>; @@ -141,7 +140,6 @@ public: _done(targets.size()), _flushDoneHistory() { - // empty } search::SerialNum @@ -219,7 +217,6 @@ public: : _flushedSerial(flushedSerial), _currentSerial(currentSerial), _start(start), _done(done), _proceed(proceed) { - // empty } void run() override { @@ -248,39 +245,45 @@ public: vespalib::Gate _taskDone; Task::UP _task; -public: - typedef std::shared_ptr<SimpleTarget> SP; - - SimpleTarget(Task::UP task, const std::string &name) : - test::DummyFlushTarget(name), - _flushedSerial(0), - _currentSerial(0), +protected: + SimpleTarget(const std::string &name, const Type &type, search::SerialNum flushedSerial = 0, bool proceedImmediately = true) : + test::DummyFlushTarget(name, type, Component::OTHER), + _flushedSerial(flushedSerial), _proceed(), _initDone(), _taskStart(), _taskDone(), - _task(std::move(task)) + _task(std::make_unique<SimpleTask>(_taskStart, _taskDone, &_proceed, + _flushedSerial, _currentSerial)) { + if (proceedImmediately) { + _proceed.countDown(); + } } - SimpleTarget(const std::string &name, search::SerialNum flushedSerial = 0, bool proceedImmediately = true) : +public: + using SP = std::shared_ptr<SimpleTarget>; + + SimpleTarget(Task::UP task, const std::string &name) : test::DummyFlushTarget(name), - _flushedSerial(flushedSerial), + _flushedSerial(0), + _currentSerial(0), _proceed(), _initDone(), _taskStart(), _taskDone(), - _task(new SimpleTask(_taskStart, _taskDone, &_proceed, - _flushedSerial, _currentSerial)) + _task(std::move(task)) { - if (proceedImmediately) { - _proceed.countDown(); - } } + SimpleTarget(search::SerialNum flushedSerial = 0, bool proceedImmediately = true) : SimpleTarget("anon", flushedSerial, proceedImmediately) { } + SimpleTarget(const std::string &name, search::SerialNum flushedSerial = 0, bool proceedImmediately = true) + : SimpleTarget(name, Type::OTHER, flushedSerial, proceedImmediately) + { } + virtual Time getLastFlushTime() const override { return fastos::ClockSystem::now(); } @@ -304,6 +307,13 @@ public: }; +class GCTarget : public SimpleTarget { +public: + GCTarget(const vespalib::string &name, search::SerialNum flushedSerial) + : SimpleTarget(name, Type::GC, flushedSerial) + {} +}; + class AssertedTarget : public SimpleTarget { public: mutable bool _mgain; @@ -366,10 +376,7 @@ public: public: typedef std::shared_ptr<SimpleStrategy> SP; - SimpleStrategy() - { - // empty - } + SimpleStrategy() {} uint32_t indexOf(const IFlushTarget::SP &target) const @@ -449,6 +456,14 @@ struct Fixture { } + void putFlushHandler(const vespalib::string &docTypeName, IFlushHandler::SP handler) { + engine.putFlushHandler(DocTypeName(docTypeName), handler); + } + + void addTargetToStrategy(IFlushTarget::SP target) { + strategy->_targets.push_back(std::move(target)); + } + std::shared_ptr<SimpleHandler> addSimpleHandler(Targets targets) { @@ -471,21 +486,17 @@ struct Fixture } }; - TEST_F("require that strategy controls flush target", Fixture(1, IINTERVAL)) { vespalib::Gate fooG, barG; std::vector<vespalib::string> order; - FlushTask::UP fooT(new AppendTask("foo", order, fooG)); - FlushTask::UP barT(new AppendTask("bar", order, barG)); - SimpleTarget::SP foo(new SimpleTarget(std::move(fooT), "foo")); - SimpleTarget::SP bar(new SimpleTarget(std::move(barT), "bar")); - f.strategy->_targets.push_back(foo); - f.strategy->_targets.push_back(bar); - - SimpleHandler::SP handler(new SimpleHandler({bar, foo})); - DocTypeName dtnvanon("anon"); - f.engine.putFlushHandler(dtnvanon, handler); + auto foo = std::make_shared<SimpleTarget>(std::make_unique<AppendTask>("foo", order, fooG), "foo"); + auto bar = std::make_shared<SimpleTarget>(std::make_unique<AppendTask>("bar", order, barG), "bar"); + f.addTargetToStrategy(foo); + f.addTargetToStrategy(bar); + + auto handler = std::make_shared<SimpleHandler>(Targets({bar, foo}), "anon"); + f.putFlushHandler("anon", handler); f.engine.start(); EXPECT_TRUE(fooG.await(LONG_TIMEOUT)); @@ -502,25 +513,20 @@ TEST_F("require that zero handlers does not core", Fixture(2, 50)) TEST_F("require that zero targets does not core", Fixture(2, 50)) { - DocTypeName dtnvfoo("foo"); - DocTypeName dtnvbar("bar"); - f.engine.putFlushHandler(dtnvfoo, - IFlushHandler::SP(new SimpleHandler({}, "foo"))); - f.engine.putFlushHandler(dtnvbar, - IFlushHandler::SP(new SimpleHandler({}, "bar"))); + f.putFlushHandler("foo", std::make_shared<SimpleHandler>(Targets(), "foo")); + f.putFlushHandler("bar", std::make_shared<SimpleHandler>(Targets(), "bar")); f.engine.start(); } TEST_F("require that oldest serial is found", Fixture(1, IINTERVAL)) { - SimpleTarget::SP foo(new SimpleTarget("foo", 10)); - SimpleTarget::SP bar(new SimpleTarget("bar", 20)); - f.strategy->_targets.push_back(foo); - f.strategy->_targets.push_back(bar); - - SimpleHandler::SP handler(new SimpleHandler({foo, bar}, "anon", 25)); - DocTypeName dtnvanon("anon"); - f.engine.putFlushHandler(dtnvanon, handler); + auto foo = std::make_shared<SimpleTarget>("foo", 10); + auto bar = std::make_shared<SimpleTarget>("bar", 20); + f.addTargetToStrategy(foo); + f.addTargetToStrategy(bar); + + auto handler = std::make_shared<SimpleHandler>(Targets({foo, bar}), "anon", 25); + f.putFlushHandler("anon", handler); f.engine.start(); EXPECT_TRUE(handler->_done.await(LONG_TIMEOUT)); @@ -529,24 +535,44 @@ TEST_F("require that oldest serial is found", Fixture(1, IINTERVAL)) EXPECT_EQUAL(FlushDoneHistory({ 10, 20, 25 }), handlerFlushDoneHistory); } +TEST_F("require that GC targets are not considered when oldest serial is found", Fixture(1, IINTERVAL)) +{ + auto foo = std::make_shared<SimpleTarget>("foo", 5); + auto bar = std::make_shared<GCTarget>("bar", 10); + auto baz = std::make_shared<SimpleTarget>("baz", 20); + f.addTargetToStrategy(foo); + f.addTargetToStrategy(bar); + f.addTargetToStrategy(baz); + + auto handler = std::make_shared<SimpleHandler>(Targets({foo, bar, baz}), "handler", 25); + f.putFlushHandler("handler", handler); + f.engine.start(); + + // The targets are flushed in sequence: 'foo', 'bar', 'baz' + EXPECT_TRUE(handler->_done.await(LONG_TIMEOUT)); + EXPECT_EQUAL(25ul, handler->_oldestSerial); + + // Before anything is flushed the oldest serial is 5. + // After 'foo' has been flushed the oldest serial is 20 as GC target 'bar' is not considered. + EXPECT_EQUAL(FlushDoneHistory({ 5, 20, 20, 25 }), handler->getFlushDoneHistory()); +} + TEST_F("require that oldest serial is found in group", Fixture(2, IINTERVAL)) { - SimpleTarget::SP fooT1(new SimpleTarget("fooT1", 10)); - SimpleTarget::SP fooT2(new SimpleTarget("fooT2", 20)); - SimpleTarget::SP barT1(new SimpleTarget("barT1", 5)); - SimpleTarget::SP barT2(new SimpleTarget("barT2", 15)); - f.strategy->_targets.push_back(fooT1); - f.strategy->_targets.push_back(fooT2); - f.strategy->_targets.push_back(barT1); - f.strategy->_targets.push_back(barT2); - - SimpleHandler::SP fooH(new SimpleHandler({fooT1, fooT2}, "fooH", 25)); - DocTypeName dtnvfoo("foo"); - f.engine.putFlushHandler(dtnvfoo, fooH); - - SimpleHandler::SP barH(new SimpleHandler({barT1, barT2}, "barH", 20)); - DocTypeName dtnvbar("bar"); - f.engine.putFlushHandler(dtnvbar, barH); + auto fooT1 = std::make_shared<SimpleTarget>("fooT1", 10); + auto fooT2 = std::make_shared<SimpleTarget>("fooT2", 20); + auto barT1 = std::make_shared<SimpleTarget>("barT1", 5); + auto barT2 = std::make_shared<SimpleTarget>("barT2", 15); + f.addTargetToStrategy(fooT1); + f.addTargetToStrategy(fooT2); + f.addTargetToStrategy(barT1); + f.addTargetToStrategy(barT2); + + auto fooH = std::make_shared<SimpleHandler>(Targets({fooT1, fooT2}), "fooH", 25); + f.putFlushHandler("foo", fooH); + + auto barH = std::make_shared<SimpleHandler>(Targets({barT1, barT2}), "barH", 20); + f.putFlushHandler("bar", barH); f.engine.start(); @@ -574,11 +600,10 @@ TEST_F("require that oldest serial is found in group", Fixture(2, IINTERVAL)) TEST_F("require that target can refuse flush", Fixture(2, IINTERVAL)) { - SimpleTarget::SP target(new SimpleTarget()); - SimpleHandler::SP handler(new SimpleHandler({target})); + auto target = std::make_shared<SimpleTarget>(); + auto handler = std::make_shared<SimpleHandler>(Targets({target})); target->_task = searchcorespi::FlushTask::UP(); - DocTypeName dtnvanon("anon"); - f.engine.putFlushHandler(dtnvanon, handler); + f.putFlushHandler("anon", handler); f.engine.start(); EXPECT_TRUE(target->_initDone.await(LONG_TIMEOUT)); @@ -589,10 +614,9 @@ TEST_F("require that target can refuse flush", Fixture(2, IINTERVAL)) TEST_F("require that targets are flushed when nothing new to flush", Fixture(2, IINTERVAL)) { - SimpleTarget::SP target(new SimpleTarget("anon", 5)); // oldest unflushed serial num = 5 - SimpleHandler::SP handler(new SimpleHandler({target}, "anon", 4)); // current serial num = 4 - DocTypeName dtnvanon("anon"); - f.engine.putFlushHandler(dtnvanon, handler); + auto target = std::make_shared<SimpleTarget>("anon", 5); // oldest unflushed serial num = 5 + auto handler = std::make_shared<SimpleHandler>(Targets({target}), "anon", 4); // current serial num = 4 + f.putFlushHandler("anon", handler); f.engine.start(); EXPECT_TRUE(target->_initDone.await(LONG_TIMEOUT)); @@ -602,14 +626,13 @@ TEST_F("require that targets are flushed when nothing new to flush", TEST_F("require that flushing targets are skipped", Fixture(2, IINTERVAL)) { - SimpleTarget::SP foo(new SimpleTarget("foo")); - SimpleTarget::SP bar(new SimpleTarget("bar")); - f.strategy->_targets.push_back(foo); - f.strategy->_targets.push_back(bar); - - SimpleHandler::SP handler(new SimpleHandler({bar, foo})); - DocTypeName dtnvanon("anon"); - f.engine.putFlushHandler(dtnvanon, handler); + auto foo = std::make_shared<SimpleTarget>("foo"); + auto bar = std::make_shared<SimpleTarget>("bar"); + f.addTargetToStrategy(foo); + f.addTargetToStrategy(bar); + + auto handler = std::make_shared<SimpleHandler>(Targets({bar, foo})); + f.putFlushHandler("anon", handler); f.engine.start(); EXPECT_TRUE(foo->_taskDone.await(LONG_TIMEOUT)); @@ -618,12 +641,11 @@ TEST_F("require that flushing targets are skipped", Fixture(2, IINTERVAL)) TEST_F("require that updated targets are not skipped", Fixture(2, IINTERVAL)) { - SimpleTarget::SP target(new SimpleTarget("target", 1)); - f.strategy->_targets.push_back(target); + auto target = std::make_shared<SimpleTarget>("target", 1); + f.addTargetToStrategy(target); - SimpleHandler::SP handler(new SimpleHandler({target}, "handler", 0)); - DocTypeName dtnvhandler("handler"); - f.engine.putFlushHandler(dtnvhandler, handler); + auto handler = std::make_shared<SimpleHandler>(Targets({target}), "handler", 0); + f.putFlushHandler("handler", handler); f.engine.start(); EXPECT_TRUE(target->_taskDone.await(LONG_TIMEOUT)); @@ -633,8 +655,7 @@ TEST("require that threaded target works") { SimpleExecutor executor; SimpleGetSerialNum getSerialNum; - IFlushTarget::SP target(new SimpleTarget()); - target.reset(new ThreadedFlushTarget(executor, getSerialNum, target)); + auto target = std::make_shared<ThreadedFlushTarget>(executor, getSerialNum, std::make_shared<SimpleTarget>()); EXPECT_FALSE(executor._done.await(SHORT_TIMEOUT)); EXPECT_TRUE(target->initFlush(0).get() != NULL); @@ -643,8 +664,7 @@ TEST("require that threaded target works") TEST("require that cached target works") { - IFlushTarget::SP target(new AssertedTarget()); - target.reset(new CachedFlushTarget(target)); + auto target = std::make_shared<CachedFlushTarget>(std::make_shared<AssertedTarget>()); for (uint32_t i = 0; i < 2; ++i) { EXPECT_EQUAL(0l, target->getApproxMemoryGain().getBefore()); EXPECT_EQUAL(0l, target->getApproxMemoryGain().getAfter()); @@ -654,12 +674,11 @@ TEST("require that cached target works") TEST_F("require that trigger flush works", Fixture(2, IINTERVAL)) { - SimpleTarget::SP target(new SimpleTarget("target", 1)); - f.strategy->_targets.push_back(target); + auto target = std::make_shared<SimpleTarget>("target", 1); + f.addTargetToStrategy(target); - SimpleHandler::SP handler(new SimpleHandler({target}, "handler", 9)); - DocTypeName dtnvhandler("handler"); - f.engine.putFlushHandler(dtnvhandler, handler); + auto handler = std::make_shared<SimpleHandler>(Targets({target}), "handler", 9); + f.putFlushHandler("handler", handler); f.engine.start(); f.engine.triggerFlush(); EXPECT_TRUE(target->_initDone.await(LONG_TIMEOUT)); @@ -693,13 +712,13 @@ assertThatHandlersInCurrentSet(FlushEngine & engine, const std::vector<const cha TEST_F("require that concurrency works", Fixture(2, 1)) { - SimpleTarget::SP target1(new SimpleTarget("target1", 1, false)); - SimpleTarget::SP target2(new SimpleTarget("target2", 2, false)); - SimpleTarget::SP target3(new SimpleTarget("target3", 3, false)); - SimpleHandler::SP handler(new SimpleHandler({target1, target2, target3}, "handler", 9)); - DocTypeName dtnvhandler("handler"); - f.engine.putFlushHandler(dtnvhandler, handler); + auto target1 = std::make_shared<SimpleTarget>("target1", 1, false); + auto target2 = std::make_shared<SimpleTarget>("target2", 2, false); + auto target3 = std::make_shared<SimpleTarget>("target3", 3, false); + auto handler = std::make_shared<SimpleHandler>(Targets({target1, target2, target3}), "handler", 9); + f.putFlushHandler("handler", handler); f.engine.start(); + EXPECT_TRUE(target1->_initDone.await(LONG_TIMEOUT)); EXPECT_TRUE(target2->_initDone.await(LONG_TIMEOUT)); EXPECT_TRUE(!target3->_initDone.await(SHORT_TIMEOUT)); @@ -714,11 +733,11 @@ TEST_F("require that concurrency works", Fixture(2, 1)) TEST_F("require that state explorer can list flush targets", Fixture(1, 1)) { - SimpleTarget::SP target = std::make_shared<SimpleTarget>("target1", 100, false); - f.engine.putFlushHandler(DocTypeName("handler"), - std::make_shared<SimpleHandler>( - Targets({target, std::make_shared<SimpleTarget>("target2", 50, true)}), - "handler", 9)); + auto target = std::make_shared<SimpleTarget>("target1", 100, false); + f.putFlushHandler("handler", + std::make_shared<SimpleHandler>( + Targets({target, std::make_shared<SimpleTarget>("target2", 50, true)}), + "handler", 9)); f.engine.start(); target->_initDone.await(LONG_TIMEOUT); target->_taskStart.await(LONG_TIMEOUT); diff --git a/searchcore/src/tests/proton/matching/matching_test.cpp b/searchcore/src/tests/proton/matching/matching_test.cpp index de6a452baf3..7c6779fdc63 100644 --- a/searchcore/src/tests/proton/matching/matching_test.cpp +++ b/searchcore/src/tests/proton/matching/matching_test.cpp @@ -53,6 +53,7 @@ using namespace search; using search::attribute::test::MockAttributeContext; using search::index::schema::DataType; using storage::spi::Timestamp; +using search::fef::indexproperties::hitcollector::HeapSize; void inject_match_phase_limiting(Properties &setup, const vespalib::string &attribute, size_t max_hits, bool descending) { @@ -287,7 +288,7 @@ struct MyWorld { Matcher::SP matcher = createMatcher(); search::fef::Properties overrides; auto mtf = matcher->create_match_tools_factory(*req, searchContext, attributeContext, metaStore, overrides); - auto diversity = mtf->createDiversifier(); + auto diversity = mtf->createDiversifier(HeapSize::lookup(config)); EXPECT_EQUAL(expectDiverse, static_cast<bool>(diversity)); } diff --git a/searchcore/src/vespa/searchcore/config/fdispatchrc.def b/searchcore/src/vespa/searchcore/config/fdispatchrc.def index e00e35c43a0..f9464815f6a 100644 --- a/searchcore/src/vespa/searchcore/config/fdispatchrc.def +++ b/searchcore/src/vespa/searchcore/config/fdispatchrc.def @@ -65,9 +65,6 @@ transport string default="" ## to the delayed ack feature present on various tcp stacks). transportnodelay bool default=true restart -## Decides if a Q is used when sending data. Q will increase throughput. -transportdirectwrite bool default=false restart - ## Minimum size of packets to compress (0 means no compression) ## packetcompresslimit int default = 1024 restart diff --git a/searchcore/src/vespa/searchcore/fdispatch/common/rpc.cpp b/searchcore/src/vespa/searchcore/fdispatch/common/rpc.cpp index 0d2f6dff983..eaff3b90d78 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/common/rpc.cpp +++ b/searchcore/src/vespa/searchcore/fdispatch/common/rpc.cpp @@ -43,12 +43,12 @@ FastS_RPC::Init(int port, const vespalib::string &myHeartbeatId) void FastS_RPC::RegisterMethods(FRT_ReflectionBuilder *rb) { - rb->DefineMethod("fs.admin.getNodeType", "", "s", true, + rb->DefineMethod("fs.admin.getNodeType", "", "s", FRT_METHOD(FastS_RPC::RPC_GetNodeType), this); rb->MethodDesc("Get string indicating the node type"); rb->ReturnDesc("type", "node type"); //---------------------------------------------------------------// - rb->DefineMethod("fs.admin.getCompileInfo", "", "*", true, + rb->DefineMethod("fs.admin.getCompileInfo", "", "*", FRT_METHOD(FastS_RPC::RPC_GetCompileInfo), this); rb->MethodDesc("Obtain compile info for this node"); rb->ReturnDesc("info", "any number of descriptive strings"); diff --git a/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp b/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp index 93680a95d44..b85e706397d 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp +++ b/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp @@ -8,6 +8,7 @@ #include <vespa/searchcore/util/eventloop.h> #include <vespa/vespalib/util/exceptions.h> #include <vespa/config/helper/configgetter.hpp> +#include <vespa/vespalib/net/crypto_engine.h> #include <vespa/log/log.h> LOG_SETUP(".fdispatch"); @@ -296,7 +297,7 @@ Fdispatch::Init() LOG(debug, "Creating FNET transport"); - _transport = std::make_unique<FNET_Transport>(_config->transportthreads); + _transport = std::make_unique<FNET_Transport>(std::make_shared<vespalib::NullCryptoEngine>(), _config->transportthreads); // disable encryption // grab node slowness limit defaults @@ -340,7 +341,6 @@ Fdispatch::Init() _nodeManager = std::make_unique<FastS_NodeManager>(_componentConfig, this, _partition); GetFNETTransport()->SetTCPNoDelay(_config->transportnodelay); - GetFNETTransport()->SetDirectWrite(_config->transportdirectwrite); if (ptportnum == 0) { throw vespalib::IllegalArgumentException("fdispatchrc.ptportnum must be non-zero, most likely an issue with config delivery."); @@ -349,7 +349,6 @@ Fdispatch::Init() _engineAdapter = std::make_unique<fdispatch::EngineAdapter>(this, _mypool.get()); _transportServer = std::make_unique<TransportServer>(*_engineAdapter, *_engineAdapter, *_engineAdapter, ptportnum, search::engine::TransportServer::DEBUG_ALL); _transportServer->setTCPNoDelay(_config->transportnodelay); - _transportServer->setDirectWrite(_config->transportdirectwrite); if (!_transportServer->start()) { _transportServer.reset(); diff --git a/searchcore/src/vespa/searchcore/fdispatch/program/rpc.cpp b/searchcore/src/vespa/searchcore/fdispatch/program/rpc.cpp index 4217ef6d8c9..56301c5e986 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/program/rpc.cpp +++ b/searchcore/src/vespa/searchcore/fdispatch/program/rpc.cpp @@ -9,13 +9,13 @@ FastS_fdispatch_RPC::RegisterMethods(FRT_ReflectionBuilder *rb) { FastS_RPC::RegisterMethods(rb); //------------------------------------------------------------------ - rb->DefineMethod("fs.admin.enableEngine", "s", "i", true, + rb->DefineMethod("fs.admin.enableEngine", "s", "i", FRT_METHOD(FastS_fdispatch_RPC::RPC_EnableEngine), this); rb->MethodDesc("Enable the given engine (clear badness)."); rb->ParamDesc("name", "engine name"); rb->ReturnDesc("count", "number of engines affected"); //------------------------------------------------------------------ - rb->DefineMethod("fs.admin.disableEngine", "s", "i", true, + rb->DefineMethod("fs.admin.disableEngine", "s", "i", FRT_METHOD(FastS_fdispatch_RPC::RPC_DisableEngine), this); rb->MethodDesc("Disable the given engine (mark as admin bad)."); rb->ParamDesc("name", "engine name"); diff --git a/searchcore/src/vespa/searchcore/grouping/groupingcontext.cpp b/searchcore/src/vespa/searchcore/grouping/groupingcontext.cpp index f2215fff978..38309284e54 100644 --- a/searchcore/src/vespa/searchcore/grouping/groupingcontext.cpp +++ b/searchcore/src/vespa/searchcore/grouping/groupingcontext.cpp @@ -2,6 +2,7 @@ #include "groupingcontext.h" #include <vespa/searchlib/aggregation/predicates.h> +#include <vespa/searchlib/aggregation/modifiers.h> namespace search { diff --git a/searchcore/src/vespa/searchcore/proton/common/eventlogger.cpp b/searchcore/src/vespa/searchcore/proton/common/eventlogger.cpp index 78f73742fed..5966589d635 100644 --- a/searchcore/src/vespa/searchcore/proton/common/eventlogger.cpp +++ b/searchcore/src/vespa/searchcore/proton/common/eventlogger.cpp @@ -109,13 +109,17 @@ EventLogger::flushStart(const string &name, int64_t beforeMemory, int64_t afterM } void -EventLogger::flushComplete(const string &name, int64_t elapsedTimeMs, +EventLogger::flushComplete(const string &name, int64_t elapsedTimeMs, SerialNum flushed, const string &outputPath, size_t outputPathElems) { JSONStringer jstr; jstr.beginObject(); jstr.appendKey("name").appendString(name); jstr.appendKey("time.elapsed.ms").appendInt64(elapsedTimeMs); + jstr.appendKey("serialnum") + .beginObject() + .appendKey("flushed").appendInt64(flushed) + .endObject(); if (!outputPath.empty()) { jstr.appendKey("output"); LogUtil::logDir(jstr, outputPath, outputPathElems); @@ -124,6 +128,20 @@ EventLogger::flushComplete(const string &name, int64_t elapsedTimeMs, EV_STATE("flush.complete", jstr.toString().data()); } +void +EventLogger::flushPrune(const string &name, SerialNum oldestFlushed) +{ + JSONStringer jstr; + jstr.beginObject(); + jstr.appendKey("name").appendString(name); + jstr.appendKey("serialnum") + .beginObject() + .appendKey("oldestflushed").appendInt64(oldestFlushed) + .endObject(); + jstr.endObject(); + EV_STATE("flush.prune", jstr.toString().data()); +} + namespace { void diff --git a/searchcore/src/vespa/searchcore/proton/common/eventlogger.h b/searchcore/src/vespa/searchcore/proton/common/eventlogger.h index 6ba8852496e..574e650732a 100644 --- a/searchcore/src/vespa/searchcore/proton/common/eventlogger.h +++ b/searchcore/src/vespa/searchcore/proton/common/eventlogger.h @@ -41,8 +41,10 @@ public: SerialNum current); static void flushComplete(const string &name, int64_t elapsedTimeMs, + SerialNum flushed, const string &outputPath, size_t outputPathElems); + static void flushPrune(const string &name, SerialNum oldestFlushed); static void loadAttributeStart(const vespalib::string &subDbName, const vespalib::string &attrName); static void loadAttributeComplete(const vespalib::string &subDbName, const vespalib::string &attrName, int64_t elapsedTimeMs); diff --git a/searchcore/src/vespa/searchcore/proton/feedoperation/CMakeLists.txt b/searchcore/src/vespa/searchcore/proton/feedoperation/CMakeLists.txt index d64fbc6722f..f5e09b81313 100644 --- a/searchcore/src/vespa/searchcore/proton/feedoperation/CMakeLists.txt +++ b/searchcore/src/vespa/searchcore/proton/feedoperation/CMakeLists.txt @@ -16,7 +16,6 @@ vespa_add_library(searchcore_feedoperation STATIC removedocumentsoperation.cpp removeoperation.cpp splitbucketoperation.cpp - spoolerreplayoperation.cpp updateoperation.cpp wipehistoryoperation.cpp DEPENDS diff --git a/searchcore/src/vespa/searchcore/proton/feedoperation/feedoperation.h b/searchcore/src/vespa/searchcore/proton/feedoperation/feedoperation.h index 77b95547bd0..3509af0de5c 100644 --- a/searchcore/src/vespa/searchcore/proton/feedoperation/feedoperation.h +++ b/searchcore/src/vespa/searchcore/proton/feedoperation/feedoperation.h @@ -33,8 +33,6 @@ public: SPLIT_BUCKET = 10, JOIN_BUCKETS = 11, PRUNE_REMOVED_DOCUMENTS = 12, - SPOOLER_REPLAY_START = 13, - SPOOLER_REPLAY_COMPLETE = 14, MOVE = 15, CREATE_BUCKET = 16, COMPACT_LID_SPACE = 17, diff --git a/searchcore/src/vespa/searchcore/proton/feedoperation/operations.h b/searchcore/src/vespa/searchcore/proton/feedoperation/operations.h index 2cdf92fc8b7..df9f22b2462 100644 --- a/searchcore/src/vespa/searchcore/proton/feedoperation/operations.h +++ b/searchcore/src/vespa/searchcore/proton/feedoperation/operations.h @@ -15,7 +15,6 @@ #include "removedocumentsoperation.h" #include "removeoperation.h" #include "splitbucketoperation.h" -#include "spoolerreplayoperation.h" #include "updateoperation.h" #include "wipehistoryoperation.h" diff --git a/searchcore/src/vespa/searchcore/proton/feedoperation/spoolerreplayoperation.cpp b/searchcore/src/vespa/searchcore/proton/feedoperation/spoolerreplayoperation.cpp deleted file mode 100644 index 16ddedc4745..00000000000 --- a/searchcore/src/vespa/searchcore/proton/feedoperation/spoolerreplayoperation.cpp +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#include "spoolerreplayoperation.h" -#include <vespa/vespalib/objects/nbostream.h> -#include <vespa/vespalib/util/stringfmt.h> - -using vespalib::make_string; - -namespace proton { - - -SpoolerReplayOperation::SpoolerReplayOperation(Type type) - : FeedOperation(type), - _spoolerSerialNum() -{ -} - -SpoolerReplayOperation::SpoolerReplayOperation(Type type, SerialNum serialNum, SerialNum spoolerSerialNum) - : FeedOperation(type), - _spoolerSerialNum(spoolerSerialNum) -{ - setSerialNum(serialNum); -} - - -void -SpoolerReplayOperation::serialize(vespalib::nbostream &os) const -{ - os << _spoolerSerialNum; -} - - -void -SpoolerReplayOperation::deserialize(vespalib::nbostream &is) -{ - is >> _spoolerSerialNum; -} - -vespalib::string SpoolerReplayOperation::toString() const { - return make_string("SpoolerReplay%s(spoolerSerialNum=%" PRIu64", serialNum=%" PRIu64 ")", - getType() == SPOOLER_REPLAY_START ? "Start" : "Complete", _spoolerSerialNum, getSerialNum()); -} - - -SpoolerReplayStartOperation::SpoolerReplayStartOperation() - : SpoolerReplayOperation(FeedOperation::SPOOLER_REPLAY_START) -{ -} - - -SpoolerReplayStartOperation::SpoolerReplayStartOperation(SerialNum serialNum, SerialNum spoolerSerialNum) - : SpoolerReplayOperation(FeedOperation::SPOOLER_REPLAY_START, - serialNum, - spoolerSerialNum) -{ -} - - -SpoolerReplayCompleteOperation::SpoolerReplayCompleteOperation() - : SpoolerReplayOperation(FeedOperation::SPOOLER_REPLAY_COMPLETE) -{ -} - - -SpoolerReplayCompleteOperation::SpoolerReplayCompleteOperation(SerialNum serialNum, - SerialNum spoolerSerialNum) - : SpoolerReplayOperation(FeedOperation::SPOOLER_REPLAY_COMPLETE, serialNum, spoolerSerialNum) -{ -} - -} // namespace proton diff --git a/searchcore/src/vespa/searchcore/proton/feedoperation/spoolerreplayoperation.h b/searchcore/src/vespa/searchcore/proton/feedoperation/spoolerreplayoperation.h deleted file mode 100644 index 028ad1c6bfa..00000000000 --- a/searchcore/src/vespa/searchcore/proton/feedoperation/spoolerreplayoperation.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#pragma once - -#include "feedoperation.h" - -namespace proton { - -class SpoolerReplayOperation : public FeedOperation -{ -private: - SerialNum _spoolerSerialNum; -protected: - SpoolerReplayOperation(Type type); - SpoolerReplayOperation(Type type, SerialNum serialNum, SerialNum spoolerSerialNum); -public: - ~SpoolerReplayOperation() override {} - SerialNum getSpoolerSerialNum() const { return _spoolerSerialNum; } - void serialize(vespalib::nbostream &os) const override; - void deserialize(vespalib::nbostream &is, const document::DocumentTypeRepo &) override { - deserialize(is); - } - void deserialize(vespalib::nbostream &is); - virtual vespalib::string toString() const override; -}; - - -/** - * Indicate that we are starting replaying the spooler log. - */ -class SpoolerReplayStartOperation : public SpoolerReplayOperation -{ -public: - SpoolerReplayStartOperation(); - /** - * @param serialNum the current serial number of the transaction log. - * @param spoolerSerialNum the serial number of the first entry of the spooler log replay. - */ - SpoolerReplayStartOperation(SerialNum serialNum, SerialNum spoolerSerialNum); -}; - - -/** - * Indicate that we are complete replaying the spooler log. - */ -class SpoolerReplayCompleteOperation : public SpoolerReplayOperation -{ -public: - SpoolerReplayCompleteOperation(); - /** - * @param serialNum the current serial number of the transaction log. - * @param spoolerSerialNum the serial number of the last entry of the spooler log replay. - */ - SpoolerReplayCompleteOperation(SerialNum serialNum, SerialNum spoolerSerialNum); -}; - -} // namespace proton - diff --git a/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.cpp b/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.cpp index 0d2c556b4d6..f7e0b7981bb 100644 --- a/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.cpp +++ b/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.cpp @@ -22,15 +22,23 @@ namespace proton { namespace { -search::SerialNum -findOldestFlushedSerial(const IFlushTarget::List &lst, const IFlushHandler &handler) +std::pair<search::SerialNum, vespalib::string> +findOldestFlushedTarget(const IFlushTarget::List &lst, const IFlushHandler &handler) { - search::SerialNum ret(handler.getCurrentSerialNumber()); - for (const IFlushTarget::SP & target : lst) { - ret = std::min(ret, target->getFlushedSerialNum()); + search::SerialNum oldestFlushedSerial = handler.getCurrentSerialNumber(); + vespalib::string oldestFlushedName = "null"; + for (const IFlushTarget::SP &target : lst) { + if (target->getType() != IFlushTarget::Type::GC) { + search::SerialNum targetFlushedSerial = target->getFlushedSerialNum(); + if (targetFlushedSerial <= oldestFlushedSerial) { + oldestFlushedSerial = targetFlushedSerial; + oldestFlushedName = target->getName(); + } + } } - LOG(debug, "Oldest flushed serial for '%s' is %" PRIu64 ".", handler.getName().c_str(), ret); - return ret; + LOG(debug, "Oldest flushed serial for handler='%s', target='%s': %" PRIu64 ".", + handler.getName().c_str(), oldestFlushedName.c_str(), oldestFlushedSerial); + return std::make_pair(oldestFlushedSerial, oldestFlushedName); } void @@ -174,6 +182,16 @@ FlushEngine::Run(FastOS_ThreadInterface *, void *) prune(); } +namespace { + +vespalib::string +createName(const IFlushHandler &handler, const vespalib::string &targetName) +{ + return (handler.getName() + "." + targetName); +} + +} + bool FlushEngine::prune() { @@ -187,7 +205,11 @@ FlushEngine::prune() } for (const auto &handler : toPrune) { IFlushTarget::List lst = handler->getFlushTargets(); - handler->flushDone(findOldestFlushedSerial(lst, *handler)); + auto oldestFlushed = findOldestFlushedTarget(lst, *handler); + if (LOG_WOULD_LOG(event)) { + EventLogger::flushPrune(createName(*handler, oldestFlushed.second), oldestFlushed.first); + } + handler->flushDone(oldestFlushed.first); } return true; } @@ -333,7 +355,8 @@ FlushEngine::flushDone(const FlushContext &ctx, uint32_t taskId) } if (LOG_WOULD_LOG(event)) { FlushStats stats = ctx.getTarget()->getLastFlushStats(); - EventLogger::flushComplete(ctx.getName(), duration.ms(), stats.getPath(), stats.getPathElementsToLog()); + EventLogger::flushComplete(ctx.getName(), duration.ms(), ctx.getTarget()->getFlushedSerialNum(), + stats.getPath(), stats.getPathElementsToLog()); } LOG(debug, "FlushEngine::flushDone(taskId='%d') took '%f' secs", taskId, duration.sec()); std::lock_guard<std::mutex> guard(_lock); diff --git a/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.cpp b/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.cpp index 52c249fe13b..4f14f709d29 100644 --- a/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.cpp +++ b/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.cpp @@ -10,11 +10,7 @@ InitializerTask::InitializerTask() { } - -InitializerTask::~InitializerTask() -{ -} - +InitializerTask::~InitializerTask() = default; void InitializerTask::addDependency(SP dependency) @@ -22,5 +18,4 @@ InitializerTask::addDependency(SP dependency) _dependencies.emplace_back(std::move(dependency)); } -} // namespace proton::initializer - +} diff --git a/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.h b/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.h index b84db9d6402..ecf98b86fc4 100644 --- a/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.h +++ b/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.h @@ -4,9 +4,7 @@ #include <memory> #include <vector> -namespace proton { - -namespace initializer { +namespace proton::initializer { /* * Class representign an initializer task, used to load a data @@ -35,6 +33,4 @@ public: virtual void run() = 0; }; -} // namespace proton::initializer - -} // namespace proton +} diff --git a/searchcore/src/vespa/searchcore/proton/initializer/task_runner.cpp b/searchcore/src/vespa/searchcore/proton/initializer/task_runner.cpp index 770f00dc264..86c2b525113 100644 --- a/searchcore/src/vespa/searchcore/proton/initializer/task_runner.cpp +++ b/searchcore/src/vespa/searchcore/proton/initializer/task_runner.cpp @@ -92,8 +92,7 @@ TaskRunner::runTask(InitializerTask::SP task) vespalib::ThreadStackExecutor executor(1, 128 * 1024); std::promise<void> promise; auto future = promise.get_future(); - runTask(task, executor, - makeLambdaTask([&]() { promise.set_value(); })); + runTask(task, executor, makeLambdaTask([&]() { promise.set_value(); })); future.wait(); } @@ -119,8 +118,7 @@ TaskRunner::runTask(InitializerTask::SP rootTask, vespalib::Executor &contextExecutor, vespalib::Executor::Task::UP doneTask) { - Context::SP context(std::make_shared<Context>(rootTask, contextExecutor, - std::move(doneTask))); + auto context(std::make_shared<Context>(rootTask, contextExecutor, std::move(doneTask))); context->execute(makeLambdaTask([=]() { pollTask(context); } )); } diff --git a/searchcore/src/vespa/searchcore/proton/initializer/task_runner.h b/searchcore/src/vespa/searchcore/proton/initializer/task_runner.h index 3b52936917c..f28c46334bc 100644 --- a/searchcore/src/vespa/searchcore/proton/initializer/task_runner.h +++ b/searchcore/src/vespa/searchcore/proton/initializer/task_runner.h @@ -6,9 +6,7 @@ #include <vespa/vespalib/stllike/hash_set.h> #include <cassert> -namespace proton { - -namespace initializer { +namespace proton::initializer { /* * Class to run multiple init tasks with dependent tasks. @@ -46,20 +44,15 @@ class TaskRunner { void schedulePoll(); }; void getReadyTasks(const InitializerTask::SP task, TaskList &readyTasks, TaskSet &checked); - void setTaskRunning(InitializerTask &task); - void setTaskDone(InitializerTask &task, Context::SP context); - void internalRunTask(InitializerTask::SP task, Context::SP context); - void internalRunTasks(const TaskList &taskList, Context::SP context); - void pollTask(Context::SP context); public: TaskRunner(vespalib::Executor &executor); - virtual ~TaskRunner(); + ~TaskRunner(); // Depecreated blocking API void runTask(InitializerTask::SP task); @@ -70,6 +63,4 @@ public: vespalib::Executor::Task::UP doneTask); }; -} // namespace proton::initializer - -} // namespace proton +} diff --git a/searchcore/src/vespa/searchcore/proton/matching/match_master.cpp b/searchcore/src/vespa/searchcore/proton/matching/match_master.cpp index d974be1ce3a..4d49e9b5d1b 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/match_master.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/match_master.cpp @@ -64,7 +64,7 @@ MatchMaster::match(const MatchParams ¶ms, fastos::StopWatch query_latency_time; query_latency_time.start(); vespalib::DualMergeDirector mergeDirector(threadBundle.size()); - MatchLoopCommunicator communicator(threadBundle.size(), params.heapSize, mtf.createDiversifier()); + MatchLoopCommunicator communicator(threadBundle.size(), params.heapSize, mtf.createDiversifier(params.heapSize)); TimedMatchLoopCommunicator timedCommunicator(communicator); DocidRangeScheduler::UP scheduler = createScheduler(threadBundle.size(), numSearchPartitions, params.numDocs); diff --git a/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp b/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp index a00a90d7a10..28d56b7e0a2 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp @@ -204,7 +204,7 @@ MatchToolsFactory::createMatchTools() const } std::unique_ptr<IDiversifier> -MatchToolsFactory::createDiversifier() const +MatchToolsFactory::createDiversifier(uint32_t heapSize) const { if ( !_diversityParams.enabled() ) { return std::unique_ptr<IDiversifier>(); @@ -214,8 +214,8 @@ MatchToolsFactory::createDiversifier() const LOG(warning, "Skipping diversity due to no %s attribute.", _diversityParams.attribute.c_str()); return std::unique_ptr<IDiversifier>(); } - size_t max_per_group = _rankSetup.getHeapSize()/_diversityParams.min_groups; - return DiversityFilter::create(*attr, _rankSetup.getHeapSize(), max_per_group, _diversityParams.min_groups, + size_t max_per_group = heapSize/_diversityParams.min_groups; + return DiversityFilter::create(*attr, heapSize, max_per_group, _diversityParams.min_groups, _diversityParams.cutoff_strategy == DiversityParams::CutoffStrategy::STRICT); } diff --git a/searchcore/src/vespa/searchcore/proton/matching/match_tools.h b/searchcore/src/vespa/searchcore/proton/matching/match_tools.h index 8f04eebc50e..0ecf6eb5b78 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/match_tools.h +++ b/searchcore/src/vespa/searchcore/proton/matching/match_tools.h @@ -124,7 +124,7 @@ public: const MaybeMatchPhaseLimiter &match_limiter() const { return *_match_limiter; } MatchTools::UP createMatchTools() const; bool should_diversify() const { return _diversityParams.enabled(); } - std::unique_ptr<search::queryeval::IDiversifier> createDiversifier() const; + std::unique_ptr<search::queryeval::IDiversifier> createDiversifier(uint32_t heapSize) const; search::queryeval::Blueprint::HitEstimate estimate() const { return _query.estimate(); } bool has_first_phase_rank() const { return !_rankSetup.getFirstPhaseRank().empty(); } std::unique_ptr<AttributeOperationTask> createOnMatchTask() const; diff --git a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp index be0a720f1c1..b32af7e3e5a 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp @@ -28,6 +28,7 @@ using search::FeatureSet; using search::attribute::IAttributeContext; using search::fef::MatchDataLayout; using search::fef::MatchData; +using search::fef::indexproperties::hitcollector::HeapSize; using search::queryeval::Blueprint; using search::queryeval::SearchIterator; using vespalib::Doom; @@ -242,14 +243,16 @@ Matcher::match(const SearchRequest &request, vespalib::ThreadBundle &threadBundl return reply; } - MatchParams params(searchContext.getDocIdLimit(), _rankSetup->getHeapSize(), _rankSetup->getArraySize(), + const Properties & rankProperties = request.propertiesMap.rankProperties(); + uint32_t heapSize = HeapSize::lookup(rankProperties, _rankSetup->getHeapSize()); + + MatchParams params(searchContext.getDocIdLimit(), heapSize, _rankSetup->getArraySize(), _rankSetup->getRankScoreDropLimit(), request.offset, request.maxhits, !_rankSetup->getSecondPhaseRank().empty(), !willNotNeedRanking(request, groupingContext)); ResultProcessor rp(attrContext, metaStore, sessionMgr, groupingContext, sessionId, request.sortSpec, params.offset, params.hits, request.should_drop_sort_data()); - const Properties & rankProperties = request.propertiesMap.rankProperties(); size_t numThreadsPerSearch = computeNumThreadsPerSearch(mtf->estimate(), rankProperties); LimitedThreadBundleWrapper limitedThreadBundle(threadBundle, numThreadsPerSearch); MatchMaster master; diff --git a/searchcore/src/vespa/searchcore/proton/server/ddbstate.cpp b/searchcore/src/vespa/searchcore/proton/server/ddbstate.cpp index 09b81f373df..73fff1cfd42 100644 --- a/searchcore/src/vespa/searchcore/proton/server/ddbstate.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/ddbstate.cpp @@ -33,10 +33,7 @@ DDBState::DDBState() } -DDBState::~DDBState() -{ - -} +DDBState::~DDBState() = default; bool diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp b/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp index da636068deb..ec6db6fa5b8 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp @@ -1186,6 +1186,13 @@ updateDocumentStoreCacheHitRate(const CacheStats ¤t, const CacheStats &las } void +updateCountMetric(uint64_t currVal, uint64_t lastVal, metrics::LongCountMetric &metric) +{ + uint64_t delta = (currVal >= lastVal) ? (currVal - lastVal) : 0; + metric.inc(delta); +} + +void updateDocstoreMetrics(LegacyDocumentDBMetrics::DocstoreMetrics &metrics, const DocumentSubDBCollection &sub_dbs, CacheStats &lastCacheStats) @@ -1200,7 +1207,7 @@ updateDocstoreMetrics(LegacyDocumentDBMetrics::DocstoreMetrics &metrics, } } metrics.memoryUsage.set(memoryUsage); - metrics.cacheLookups.set(cache_stats.lookups()); + updateCountMetric(cache_stats.lookups(), lastCacheStats.lookups(), metrics.cacheLookups); updateDocumentStoreCacheHitRate(cache_stats, lastCacheStats, metrics.cacheHitRate); metrics.hits = cache_stats.hits; metrics.cacheElements.set(cache_stats.elements); @@ -1225,8 +1232,8 @@ updateDocumentStoreMetrics(DocumentDBTaggedMetrics::SubDBMetrics::DocumentStoreM metrics.cache.memoryUsage.set(cacheStats.memory_used); metrics.cache.elements.set(cacheStats.elements); updateDocumentStoreCacheHitRate(cacheStats, lastCacheStats, metrics.cache.hitRate); - metrics.cache.lookups.set(cacheStats.lookups()); - metrics.cache.invalidations.set(cacheStats.invalidations); + updateCountMetric(cacheStats.lookups(), lastCacheStats.lookups(), metrics.cache.lookups); + updateCountMetric(cacheStats.invalidations, lastCacheStats.invalidations, metrics.cache.invalidations); lastCacheStats = cacheStats; } diff --git a/searchcore/src/vespa/searchcore/proton/server/feedstates.cpp b/searchcore/src/vespa/searchcore/proton/server/feedstates.cpp index ae323bc93df..e45e3d7e423 100644 --- a/searchcore/src/vespa/searchcore/proton/server/feedstates.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/feedstates.cpp @@ -112,12 +112,6 @@ public: virtual void replay(const PruneRemovedDocumentsOperation &op) override { _feed_view_ptr->handlePruneRemovedDocuments(op); } - virtual void replay(const SpoolerReplayStartOperation &op) override { - (void) op; - } - virtual void replay(const SpoolerReplayCompleteOperation &op) override { - (void) op; - } virtual void replay(const MoveOperation &op) override { _feed_view_ptr->handleMove(op, search::IDestructorCallback::SP()); } diff --git a/searchcore/src/vespa/searchcore/proton/server/ireplaypackethandler.h b/searchcore/src/vespa/searchcore/proton/server/ireplaypackethandler.h index e93821e2a36..9acc3a530d3 100644 --- a/searchcore/src/vespa/searchcore/proton/server/ireplaypackethandler.h +++ b/searchcore/src/vespa/searchcore/proton/server/ireplaypackethandler.h @@ -16,8 +16,6 @@ class DeleteBucketOperation; class SplitBucketOperation; class JoinBucketsOperation; class PruneRemovedDocumentsOperation; -class SpoolerReplayStartOperation; -class SpoolerReplayCompleteOperation; class MoveOperation; class CreateBucketOperation; class CompactLidSpaceOperation; @@ -41,8 +39,6 @@ struct IReplayPacketHandler virtual void replay(const SplitBucketOperation &op) = 0; virtual void replay(const JoinBucketsOperation &op) = 0; virtual void replay(const PruneRemovedDocumentsOperation &op) = 0; - virtual void replay(const SpoolerReplayStartOperation &op) = 0; - virtual void replay(const SpoolerReplayCompleteOperation &op) = 0; virtual void replay(const MoveOperation &op) = 0; virtual void replay(const CreateBucketOperation &op) = 0; virtual void replay(const CompactLidSpaceOperation &op) = 0; diff --git a/searchcore/src/vespa/searchcore/proton/server/replaypacketdispatcher.cpp b/searchcore/src/vespa/searchcore/proton/server/replaypacketdispatcher.cpp index 42451f08315..447bd0d8624 100644 --- a/searchcore/src/vespa/searchcore/proton/server/replaypacketdispatcher.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/replaypacketdispatcher.cpp @@ -74,14 +74,6 @@ ReplayPacketDispatcher::replayEntry(const Packet::Entry &entry) PruneRemovedDocumentsOperation op; replay(op, is, entry); break; - } case FeedOperation::SPOOLER_REPLAY_START: { - SpoolerReplayStartOperation op; - replay(op, is, entry); - break; - } case FeedOperation::SPOOLER_REPLAY_COMPLETE: { - SpoolerReplayCompleteOperation op; - replay(op, is, entry); - break; } case FeedOperation::MOVE: { MoveOperation op; replay(op, is, entry); diff --git a/searchcore/src/vespa/searchcore/proton/server/rpc_hooks.cpp b/searchcore/src/vespa/searchcore/proton/server/rpc_hooks.cpp index ab012760762..6e442f472b1 100644 --- a/searchcore/src/vespa/searchcore/proton/server/rpc_hooks.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/rpc_hooks.cpp @@ -127,7 +127,7 @@ RPCHooksBase::initRPC() FRT_ReflectionBuilder rb(_orb.get()); //------------------------------------------------------------------------- - rb.DefineMethod("pandora.rtc.getState", "ii", "SSi", true, + rb.DefineMethod("pandora.rtc.getState", "ii", "SSi", FRT_METHOD(RPCHooksBase::rpc_GetState), this); rb.MethodDesc("Get the current state of node"); rb.ParamDesc("gencnt", "old state generation held by the client"); @@ -136,7 +136,7 @@ RPCHooksBase::initRPC() rb.ReturnDesc("values", "Array of state values"); rb.ReturnDesc("newgen", "New state generation count"); //------------------------------------------------------------------------- - rb.DefineMethod("proton.getStatus", "s", "SSSS", true, + rb.DefineMethod("proton.getStatus", "s", "SSSS", FRT_METHOD(RPCHooksBase::rpc_GetProtonStatus), this); rb.MethodDesc("Get the current state of proton or a proton component"); rb.ParamDesc("component", "Which component to check the status for"); @@ -145,7 +145,7 @@ RPCHooksBase::initRPC() rb.ReturnDesc("internalStates", "Array of internal states "); rb.ReturnDesc("message", "Array of status messages"); //------------------------------------------------------------------------- - rb.DefineMethod("pandora.rtc.getIncrementalState", "i", "SSi", true, + rb.DefineMethod("pandora.rtc.getIncrementalState", "i", "SSi", FRT_METHOD(RPCHooksBase::rpc_getIncrementalState), this); rb.MethodDesc("Get node state changes since last invocation"); rb.ParamDesc("timeout", "How many milliseconds to wait for state update"); @@ -153,26 +153,26 @@ RPCHooksBase::initRPC() rb.ReturnDesc("values", "Array of state values"); rb.ReturnDesc("dummy", "Dummy value to enable code reuse"); //------------------------------------------------------------------------- - rb.DefineMethod("pandora.rtc.shutdown", "", "", true, + rb.DefineMethod("pandora.rtc.shutdown", "", "", FRT_METHOD(RPCHooksBase::rpc_Shutdown), this); rb.MethodDesc("Shut down the rtc application"); //------------------------------------------------------------------------- - rb.DefineMethod("pandora.rtc.die", "", "", true, + rb.DefineMethod("pandora.rtc.die", "", "", FRT_METHOD(RPCHooksBase::rpc_die), this); rb.MethodDesc("Exit the rtc application without cleanup"); //------------------------------------------------------------------------- - rb.DefineMethod("proton.triggerFlush", "", "b", true, + rb.DefineMethod("proton.triggerFlush", "", "b", FRT_METHOD(RPCHooksBase::rpc_triggerFlush), this); rb.MethodDesc("Tell the node to trigger flush ASAP"); rb.ReturnDesc("success", "Whether or not a flush was triggered."); //------------------------------------------------------------------------- - rb.DefineMethod("proton.prepareRestart", "", "b", true, + rb.DefineMethod("proton.prepareRestart", "", "b", FRT_METHOD(RPCHooksBase::rpc_prepareRestart), this); rb.MethodDesc("Tell the node to prepare for a restart by flushing components " "such that TLS replay time + time spent flushing components is as low as possible"); rb.ReturnDesc("success", "Whether or not prepare for restart was triggered."); //------------------------------------------------------------------------- - rb.DefineMethod("proton.getDocsums", "bix", "bix", true, FRT_METHOD(RPCHooksBase::rpc_getDocSums), this); + rb.DefineMethod("proton.getDocsums", "bix", "bix", FRT_METHOD(RPCHooksBase::rpc_getDocSums), this); rb.MethodDesc("Get list of document summaries"); rb.ParamDesc("encoding", "0=raw, 6=lz4"); rb.ParamDesc("uncompressedBlobSize", "Uncompressed blob size"); diff --git a/searchcore/src/vespa/searchcore/proton/server/transactionlogmanager.cpp b/searchcore/src/vespa/searchcore/proton/server/transactionlogmanager.cpp index 4bece3e6860..72fcf812ebc 100644 --- a/searchcore/src/vespa/searchcore/proton/server/transactionlogmanager.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/transactionlogmanager.cpp @@ -31,8 +31,7 @@ TransactionLogManager::TransactionLogManager(const vespalib::string &tlsSpec, { } -TransactionLogManager::~TransactionLogManager() { -} +TransactionLogManager::~TransactionLogManager() = default; void TransactionLogManager::init(SerialNum oldestConfigSerial, diff --git a/searchlib/CMakeLists.txt b/searchlib/CMakeLists.txt index 1ad8f562384..c55aadb5eae 100644 --- a/searchlib/CMakeLists.txt +++ b/searchlib/CMakeLists.txt @@ -128,6 +128,7 @@ vespa_define_module( src/tests/engine/monitorapi src/tests/engine/searchapi src/tests/engine/transportserver + src/tests/expression/attributenode src/tests/features src/tests/features/beta src/tests/features/constant diff --git a/searchlib/src/main/java/com/yahoo/searchlib/expression/AttributeMapLookupNode.java b/searchlib/src/main/java/com/yahoo/searchlib/expression/AttributeMapLookupNode.java new file mode 100644 index 00000000000..d15b4086e42 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/expression/AttributeMapLookupNode.java @@ -0,0 +1,92 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.expression; + +import com.yahoo.vespa.objects.Deserializer; +import com.yahoo.vespa.objects.ObjectVisitor; +import com.yahoo.vespa.objects.Serializer; + +import java.util.Objects; + +/** + * This function is an instruction to do a lookup in a map attribute, returning the value. + * + * The key is either specified explicitly or found via a key source attribute. + * Two underlying attributes are used to represent the map attribute (the key and value attributes). + * + * @author geirst + */ +public class AttributeMapLookupNode extends AttributeNode { + + public static final int classId = registerClass(0x4000 + 145, AttributeMapLookupNode.class); + private String keyAttribute = ""; + private String valueAttribute = ""; + private String key = ""; + private String keySourceAttribute = ""; + + private AttributeMapLookupNode(String attributeExpression, String keyAttribute, String valueAttribute, + String key, String keySourceAttribute) { + super(attributeExpression); + this.keyAttribute = keyAttribute; + this.valueAttribute = valueAttribute; + this.key = key; + this.keySourceAttribute = keySourceAttribute; + } + + public AttributeMapLookupNode() { + } + + public static AttributeMapLookupNode fromKey(String attributeExpression, String keyAttribute, String valueAttribute, String key) { + return new AttributeMapLookupNode(attributeExpression, keyAttribute, valueAttribute, key, ""); + } + + public static AttributeMapLookupNode fromKeySourceAttribute(String attributeExpression, String keyAttribute, String valueAttribute, String keySourceAttribute) { + return new AttributeMapLookupNode(attributeExpression, keyAttribute, valueAttribute, "", keySourceAttribute); + } + + @Override + protected int onGetClassId() { + return classId; + } + + @Override + protected void onSerialize(Serializer buf) { + super.onSerialize(buf); + putUtf8(buf, keyAttribute); + putUtf8(buf, valueAttribute); + putUtf8(buf, key); + putUtf8(buf, keySourceAttribute); + } + + @Override + protected void onDeserialize(Deserializer buf) { + super.onDeserialize(buf); + keyAttribute = getUtf8(buf); + valueAttribute = getUtf8(buf); + key = getUtf8(buf); + keySourceAttribute = getUtf8(buf); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), keyAttribute, valueAttribute, key, keySourceAttribute); + } + + @Override + protected boolean equalsFunction(FunctionNode obj) { + AttributeMapLookupNode that = (AttributeMapLookupNode) obj; + return super.equalsFunction(obj) && + Objects.equals(keyAttribute, that.keyAttribute) && + Objects.equals(valueAttribute, that.valueAttribute) && + Objects.equals(key, that.key) && + Objects.equals(keySourceAttribute, that.keySourceAttribute); + } + + @Override + public void visitMembers(ObjectVisitor visitor) { + super.visitMembers(visitor); + visitor.visit("keyAttribute", keyAttribute); + visitor.visit("valueAttribute", valueAttribute); + visitor.visit("key", key); + visitor.visit("keySourceAttribute", keySourceAttribute); + } +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index f7fe91cb56f..ac5eefcc5b2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -23,6 +23,7 @@ public class ImportedModel { private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); private final String name; + private final String source; private final Map<String, Signature> signatures = new HashMap<>(); private final Map<String, TensorType> arguments = new HashMap<>(); @@ -36,16 +37,21 @@ public class ImportedModel { * Creates a new imported model. * * @param name the name of this mode, containing only characters in [A-Za-z0-9_] + * @param source the source path (directory or file) of this model */ - public ImportedModel(String name) { + public ImportedModel(String name, String source) { if ( ! nameRegexp.matcher(name).matches()) throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + name + "'"); this.name = name; + this.source = source; } /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ public String name() { return name; } + /** Returns the source path (directiry or file) of this model */ + public String source() { return source; } + /** Returns an immutable map of the arguments ("Placeholders") of this */ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java index 92cb8c3f360..40d1ca8030a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java @@ -6,7 +6,10 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.path.Path; import java.io.File; +import java.util.Arrays; import java.util.Collection; +import java.util.HashMap; +import java.util.Map; import java.util.Optional; /** @@ -30,25 +33,30 @@ public class ImportedModels { } public ImportedModels(File modelsDirectory) { - ImmutableMap.Builder<String, ImportedModel> builder = new ImmutableMap.Builder<>(); + Map<String, ImportedModel> models = new HashMap<>(); // Find all subdirectories recursively which contains a model we can read - importRecursively(modelsDirectory, builder); - importedModels = builder.build(); + importRecursively(modelsDirectory, models); + importedModels = ImmutableMap.copyOf(models); } - private static void importRecursively(File dir, ImmutableMap.Builder<String, ImportedModel> builder) { + private static void importRecursively(File dir, Map<String, ImportedModel> models) { if ( ! dir.isDirectory()) return; - for (File child : dir.listFiles()) { + + Arrays.stream(dir.listFiles()).sorted().forEach(child -> { Optional<ModelImporter> importer = findImporterOf(child); if (importer.isPresent()) { String name = toName(child); - builder.put(name, importer.get().importModel(name, child)); + ImportedModel existing = models.get(name); + if (existing != null) + throw new IllegalArgumentException("The models in " + child + " and " + existing.source() + + " both resolve to the model name '" + name + "'"); + models.put(name, importer.get().importModel(name, child)); } else { - importRecursively(child, builder); + importRecursively(child, models); } - } + }); } private static Optional<ModelImporter> findImporterOf(File path) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java index 13718935cef..2ae107a5770 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java @@ -32,7 +32,7 @@ public abstract class ModelImporter { private static final Logger log = Logger.getLogger(ModelImporter.class.getName()); - /** Returns whether the file or directory at the given path is of the tyope which can be imported by this */ + /** Returns whether the file or directory at the given path is of the type which can be imported by this */ public abstract boolean canImport(String modelPath); /** Imports the given model */ @@ -46,8 +46,8 @@ public abstract class ModelImporter { * Takes an IntermediateGraph and converts it to a ImportedModel containing * the actual Vespa ranking expressions. */ - static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph) { - ImportedModel model = new ImportedModel(graph.name()); + static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { + ImportedModel model = new ImportedModel(graph.name(), modelSource); graph.optimize(); @@ -139,7 +139,7 @@ public abstract class ModelImporter { Value value = operation.getConstantValue().orElseThrow(() -> new IllegalArgumentException("Operation '" + operation.vespaName() + "' " + - "is constant but does not have a value.")); + "is constant but does not have a value.")); if ( ! (value instanceof TensorValue)) { return operation.function(); // scalar values are inserted directly into the expression } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java index 187e2f2e29d..917b0d6a389 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java @@ -31,7 +31,7 @@ public class OnnxImporter extends ModelImporter { try (FileInputStream inputStream = new FileInputStream(modelPath)) { Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); IntermediateGraph graph = GraphImporter.importGraph(modelName, model); - return convertIntermediateGraphToModel(graph); + return convertIntermediateGraphToModel(graph, modelPath); } catch (IOException e) { throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java index afd01b3d7da..7c18e04bae7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java @@ -39,7 +39,7 @@ public class TensorFlowImporter extends ModelImporter { @Override public ImportedModel importModel(String modelName, String modelDir) { try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { - return importModel(modelName, model); + return importModel(modelName, modelDir, model); } catch (IllegalArgumentException e) { throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); @@ -47,10 +47,10 @@ public class TensorFlowImporter extends ModelImporter { } /** Imports a TensorFlow model */ - ImportedModel importModel(String modelName, SavedModelBundle model) { + ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) { try { IntermediateGraph graph = GraphImporter.importGraph(modelName, model); - return convertIntermediateGraphToModel(graph); + return convertIntermediateGraphToModel(graph, modelDir); } catch (IOException e) { throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java index e08214579db..725f319a839 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java @@ -27,7 +27,7 @@ public class XGBoostImporter extends ModelImporter { @Override public ImportedModel importModel(String modelName, String modelPath) { try { - ImportedModel model = new ImportedModel(modelName); + ImportedModel model = new ImportedModel(modelName, modelPath); XGBoostParser parser = new XGBoostParser(modelPath); RankingExpression expression = new RankingExpression(parser.toRankingExpression()); model.expression(modelName, expression); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java index 39a8b211d09..eee92862e7f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java @@ -10,8 +10,8 @@ import java.util.Map; import java.util.Set; /** - * Holds an intermediate representation of an imported ONNX or TensorFlow - * graph. After this intermediate representation is constructed, it is used to + * Holds an intermediate representation of an imported model graph. + * After this intermediate representation is constructed, it is used to * simplify and optimize the computational graph and then converted into the * final ImportedModel that holds the Vespa ranking expressions for the model. * diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java index 06d5ad187d8..a541eac2421 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java @@ -12,7 +12,7 @@ import java.util.ArrayList; import java.util.List; /** - * Replaces constant reference pseudofeatures in expressions by their constant value + * Replaces constant reference pseudofeatures which are scalars by their constant value * * @author bratseth */ diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt b/searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt new file mode 100644 index 00000000000..eb926836576 --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt @@ -0,0 +1,7982 @@ +saved_model_schema_version: 1 +meta_graphs { + meta_info_def { + stripped_op_list { + op { + name: "Add" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_STRING + } + } + } + } + op { + name: "AddN" + input_arg { + name: "inputs" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "sum" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + type: DT_VARIANT + } + } + } + is_aggregate: true + is_commutative: true + } + op { + name: "ApplyGradientDescent" + input_arg { + name: "var" + type_attr: "T" + is_ref: true + } + input_arg { + name: "alpha" + type_attr: "T" + } + input_arg { + name: "delta" + type_attr: "T" + } + output_arg { + name: "out" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: false + } + } + } + op { + name: "Assign" + input_arg { + name: "ref" + type_attr: "T" + is_ref: true + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output_ref" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + } + attr { + name: "validate_shape" + type: "bool" + default_value { + b: true + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: true + } + } + allows_uninitialized_input: true + } + op { + name: "BroadcastGradientArgs" + input_arg { + name: "s0" + type_attr: "T" + } + input_arg { + name: "s1" + type_attr: "T" + } + output_arg { + name: "r0" + type_attr: "T" + } + output_arg { + name: "r1" + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Cast" + input_arg { + name: "x" + type_attr: "SrcT" + } + output_arg { + name: "y" + type_attr: "DstT" + } + attr { + name: "SrcT" + type: "type" + } + attr { + name: "DstT" + type: "type" + } + } + op { + name: "Const" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "value" + type: "tensor" + } + attr { + name: "dtype" + type: "type" + } + } + op { + name: "ExpandDims" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "dim" + type_attr: "Tdim" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tdim" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Fill" + input_arg { + name: "dims" + type: DT_INT32 + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "FloorDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "GreaterEqual" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type: DT_BOOL + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_UINT16 + type: DT_HALF + } + } + } + } + op { + name: "Identity" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "InTopKV2" + input_arg { + name: "predictions" + type: DT_FLOAT + } + input_arg { + name: "targets" + type_attr: "T" + } + input_arg { + name: "k" + type_attr: "T" + } + output_arg { + name: "precision" + type: DT_BOOL + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "MatMul" + input_arg { + name: "a" + type_attr: "T" + } + input_arg { + name: "b" + type_attr: "T" + } + output_arg { + name: "product" + type_attr: "T" + } + attr { + name: "transpose_a" + type: "bool" + default_value { + b: false + } + } + attr { + name: "transpose_b" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Maximum" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + is_commutative: true + } + op { + name: "Mean" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "MergeV2Checkpoints" + input_arg { + name: "checkpoint_prefixes" + type: DT_STRING + } + input_arg { + name: "destination_prefix" + type: DT_STRING + } + attr { + name: "delete_old_dirs" + type: "bool" + default_value { + b: true + } + } + is_stateful: true + } + op { + name: "Mul" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "NoOp" + } + op { + name: "Pack" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + } + attr { + name: "axis" + type: "int" + default_value { + i: 0 + } + } + } + op { + name: "Placeholder" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + } + op { + name: "PreventGradient" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "message" + type: "string" + default_value { + s: "" + } + } + } + op { + name: "Prod" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RealDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Reshape" + input_arg { + name: "tensor" + type_attr: "T" + } + input_arg { + name: "shape" + type_attr: "Tshape" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tshape" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RestoreV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + output_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "SaveV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + input_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "ScalarSummary" + input_arg { + name: "tags" + type: DT_STRING + } + input_arg { + name: "values" + type_attr: "T" + } + output_arg { + name: "summary" + type: DT_STRING + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_UINT16 + type: DT_HALF + } + } + } + } + op { + name: "Select" + input_arg { + name: "condition" + type: DT_BOOL + } + input_arg { + name: "t" + type_attr: "T" + } + input_arg { + name: "e" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "Selu" + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "activations" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + } + op { + name: "SeluGrad" + input_arg { + name: "gradients" + type_attr: "T" + } + input_arg { + name: "outputs" + type_attr: "T" + } + output_arg { + name: "backprops" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + } + op { + name: "Shape" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "T" + type: "type" + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "ShardedFilename" + input_arg { + name: "basename" + type: DT_STRING + } + input_arg { + name: "shard" + type: DT_INT32 + } + input_arg { + name: "num_shards" + type: DT_INT32 + } + output_arg { + name: "filename" + type: DT_STRING + } + } + op { + name: "SparseSoftmaxCrossEntropyWithLogits" + input_arg { + name: "features" + type_attr: "T" + } + input_arg { + name: "labels" + type_attr: "Tlabels" + } + output_arg { + name: "loss" + type_attr: "T" + } + output_arg { + name: "backprop" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "Tlabels" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "StringJoin" + input_arg { + name: "inputs" + type: DT_STRING + number_attr: "N" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "separator" + type: "string" + default_value { + s: "" + } + } + } + op { + name: "Sum" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Tile" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "multiples" + type_attr: "Tmultiples" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tmultiples" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "TruncatedNormal" + input_arg { + name: "shape" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "seed" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "seed2" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true + } + op { + name: "VariableV2" + output_arg { + name: "ref" + type_attr: "dtype" + is_ref: true + } + attr { + name: "shape" + type: "shape" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true + } + op { + name: "ZerosLike" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + } + tags: "serve" + tensorflow_version: "1.4.1" + tensorflow_git_version: "v1.4.0-19-ga52c8d9" + } + graph_def { + node { + name: "input" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + node { + name: "y" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "shape" + value { + shape { + unknown_rank: true + } + } + } + } + node { + name: "dnn/hidden1/truncated_normal/shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\020\003\000\000,\001\000\000" + } + } + } + } + node { + name: "dnn/hidden1/truncated_normal/mean" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } + } + node { + name: "dnn/hidden1/truncated_normal/stddev" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0714285746216774 + } + } + } + } + node { + name: "dnn/hidden1/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "dnn/hidden1/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + } + node { + name: "dnn/hidden1/truncated_normal/mul" + op: "Mul" + input: "dnn/hidden1/truncated_normal/TruncatedNormal" + input: "dnn/hidden1/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "dnn/hidden1/truncated_normal" + op: "Add" + input: "dnn/hidden1/truncated_normal/mul" + input: "dnn/hidden1/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "dnn/hidden1/weights" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "dnn/hidden1/weights/Assign" + op: "Assign" + input: "dnn/hidden1/weights" + input: "dnn/hidden1/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "dnn/hidden1/weights/read" + op: "Identity" + input: "dnn/hidden1/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "dnn/hidden1/zeros" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 300 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "dnn/hidden1/bias" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 300 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "dnn/hidden1/bias/Assign" + op: "Assign" + input: "dnn/hidden1/bias" + input: "dnn/hidden1/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "dnn/hidden1/bias/read" + op: "Identity" + input: "dnn/hidden1/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + } + node { + name: "dnn/hidden1/MatMul" + op: "MatMul" + input: "input" + input: "dnn/hidden1/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "dnn/hidden1/add" + op: "Add" + input: "dnn/hidden1/MatMul" + input: "dnn/hidden1/bias/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "dnn/hidden1/mul/x" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.009999999776482582 + } + } + } + } + node { + name: "dnn/hidden1/mul" + op: "Mul" + input: "dnn/hidden1/mul/x" + input: "dnn/hidden1/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "dnn/hidden1/Maximum" + op: "Maximum" + input: "dnn/hidden1/mul" + input: "dnn/hidden1/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "dnn/hidden2/truncated_normal/shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: ",\001\000\000d\000\000\000" + } + } + } + } + node { + name: "dnn/hidden2/truncated_normal/mean" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } + } + node { + name: "dnn/hidden2/truncated_normal/stddev" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.1154700517654419 + } + } + } + } + node { + name: "dnn/hidden2/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "dnn/hidden2/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + } + node { + name: "dnn/hidden2/truncated_normal/mul" + op: "Mul" + input: "dnn/hidden2/truncated_normal/TruncatedNormal" + input: "dnn/hidden2/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "dnn/hidden2/truncated_normal" + op: "Add" + input: "dnn/hidden2/truncated_normal/mul" + input: "dnn/hidden2/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "dnn/hidden2/weights" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "dnn/hidden2/weights/Assign" + op: "Assign" + input: "dnn/hidden2/weights" + input: "dnn/hidden2/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "dnn/hidden2/weights/read" + op: "Identity" + input: "dnn/hidden2/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "dnn/hidden2/zeros" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 100 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "dnn/hidden2/bias" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "dnn/hidden2/bias/Assign" + op: "Assign" + input: "dnn/hidden2/bias" + input: "dnn/hidden2/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "dnn/hidden2/bias/read" + op: "Identity" + input: "dnn/hidden2/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + } + node { + name: "dnn/hidden2/MatMul" + op: "MatMul" + input: "dnn/hidden1/Maximum" + input: "dnn/hidden2/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "dnn/hidden2/add" + op: "Add" + input: "dnn/hidden2/MatMul" + input: "dnn/hidden2/bias/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "dnn/hidden2/Selu" + op: "Selu" + input: "dnn/hidden2/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "dnn/outputs/truncated_normal/shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "d\000\000\000\n\000\000\000" + } + } + } + } + node { + name: "dnn/outputs/truncated_normal/mean" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } + } + node { + name: "dnn/outputs/truncated_normal/stddev" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.20000000298023224 + } + } + } + } + node { + name: "dnn/outputs/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "dnn/outputs/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + } + node { + name: "dnn/outputs/truncated_normal/mul" + op: "Mul" + input: "dnn/outputs/truncated_normal/TruncatedNormal" + input: "dnn/outputs/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "dnn/outputs/truncated_normal" + op: "Add" + input: "dnn/outputs/truncated_normal/mul" + input: "dnn/outputs/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "dnn/outputs/weights" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "dnn/outputs/weights/Assign" + op: "Assign" + input: "dnn/outputs/weights" + input: "dnn/outputs/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "dnn/outputs/weights/read" + op: "Identity" + input: "dnn/outputs/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "dnn/outputs/zeros" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 10 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "dnn/outputs/bias" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "dnn/outputs/bias/Assign" + op: "Assign" + input: "dnn/outputs/bias" + input: "dnn/outputs/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "dnn/outputs/bias/read" + op: "Identity" + input: "dnn/outputs/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "dnn/outputs/MatMul" + op: "MatMul" + input: "dnn/hidden2/Selu" + input: "dnn/outputs/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "dnn/outputs/add" + op: "Add" + input: "dnn/outputs/MatMul" + input: "dnn/outputs/bias/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "loss/SparseSoftmaxCrossEntropyWithLogits/Shape" + op: "Shape" + input: "y" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" + op: "SparseSoftmaxCrossEntropyWithLogits" + input: "dnn/outputs/add" + input: "y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tlabels" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "loss/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "loss/loss" + op: "Mean" + input: "loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" + input: "loss/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/Shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "train/gradients/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } + } + node { + name: "train/gradients/Fill" + op: "Fill" + input: "train/gradients/Shape" + input: "train/gradients/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Reshape/shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Reshape" + op: "Reshape" + input: "train/gradients/Fill" + input: "train/gradients/loss/loss_grad/Reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Shape" + op: "Shape" + input: "loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/loss/loss_grad/Tile" + op: "Tile" + input: "train/gradients/loss/loss_grad/Reshape" + input: "train/gradients/loss/loss_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tmultiples" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Shape_1" + op: "Shape" + input: "loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/loss/loss_grad/Shape_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/loss/loss_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Prod" + op: "Prod" + input: "train/gradients/loss/loss_grad/Shape_1" + input: "train/gradients/loss/loss_grad/Const" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/loss/loss_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/loss/loss_grad/Const_1" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/loss/loss_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Prod_1" + op: "Prod" + input: "train/gradients/loss/loss_grad/Shape_2" + input: "train/gradients/loss/loss_grad/Const_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/loss/loss_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/loss/loss_grad/Maximum/y" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/loss/loss_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Maximum" + op: "Maximum" + input: "train/gradients/loss/loss_grad/Prod_1" + input: "train/gradients/loss/loss_grad/Maximum/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/loss/loss_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/floordiv" + op: "FloorDiv" + input: "train/gradients/loss/loss_grad/Prod" + input: "train/gradients/loss/loss_grad/Maximum" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/loss/loss_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Cast" + op: "Cast" + input: "train/gradients/loss/loss_grad/floordiv" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/truediv" + op: "RealDiv" + input: "train/gradients/loss/loss_grad/Tile" + input: "train/gradients/loss/loss_grad/Cast" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "train/gradients/zeros_like" + op: "ZerosLike" + input: "loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/PreventGradient" + op: "PreventGradient" + input: "loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "message" + value { + s: "Currently there is no way to take the second derivative of sparse_softmax_cross_entropy_with_logits due to the fused implementation\'s interaction with tf.gradients()" + } + } + } + node { + name: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: -1 + } + } + } + } + node { + name: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/ExpandDims" + op: "ExpandDims" + input: "train/gradients/loss/loss_grad/truediv" + input: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/mul" + op: "Mul" + input: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/ExpandDims" + input: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/PreventGradient" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/Shape" + op: "Shape" + input: "dnn/outputs/MatMul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/Shape_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 10 + } + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "train/gradients/dnn/outputs/add_grad/Shape" + input: "train/gradients/dnn/outputs/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/Sum" + op: "Sum" + input: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/mul" + input: "train/gradients/dnn/outputs/add_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/Reshape" + op: "Reshape" + input: "train/gradients/dnn/outputs/add_grad/Sum" + input: "train/gradients/dnn/outputs/add_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/Sum_1" + op: "Sum" + input: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/mul" + input: "train/gradients/dnn/outputs/add_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/Reshape_1" + op: "Reshape" + input: "train/gradients/dnn/outputs/add_grad/Sum_1" + input: "train/gradients/dnn/outputs/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/outputs/add_grad/Reshape" + input: "^train/gradients/dnn/outputs/add_grad/Reshape_1" + } + node { + name: "train/gradients/dnn/outputs/add_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/outputs/add_grad/Reshape" + input: "^train/gradients/dnn/outputs/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/outputs/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/outputs/add_grad/Reshape_1" + input: "^train/gradients/dnn/outputs/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/outputs/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/outputs/MatMul_grad/MatMul" + op: "MatMul" + input: "train/gradients/dnn/outputs/add_grad/tuple/control_dependency" + input: "dnn/outputs/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } + } + node { + name: "train/gradients/dnn/outputs/MatMul_grad/MatMul_1" + op: "MatMul" + input: "dnn/hidden2/Selu" + input: "train/gradients/dnn/outputs/add_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/outputs/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/outputs/MatMul_grad/MatMul" + input: "^train/gradients/dnn/outputs/MatMul_grad/MatMul_1" + } + node { + name: "train/gradients/dnn/outputs/MatMul_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/outputs/MatMul_grad/MatMul" + input: "^train/gradients/dnn/outputs/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/outputs/MatMul_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/outputs/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/outputs/MatMul_grad/MatMul_1" + input: "^train/gradients/dnn/outputs/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/outputs/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/Selu_grad/SeluGrad" + op: "SeluGrad" + input: "train/gradients/dnn/outputs/MatMul_grad/tuple/control_dependency" + input: "dnn/hidden2/Selu" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/Shape" + op: "Shape" + input: "dnn/hidden2/MatMul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/Shape_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 100 + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "train/gradients/dnn/hidden2/add_grad/Shape" + input: "train/gradients/dnn/hidden2/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/Sum" + op: "Sum" + input: "train/gradients/dnn/hidden2/Selu_grad/SeluGrad" + input: "train/gradients/dnn/hidden2/add_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/Reshape" + op: "Reshape" + input: "train/gradients/dnn/hidden2/add_grad/Sum" + input: "train/gradients/dnn/hidden2/add_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/Sum_1" + op: "Sum" + input: "train/gradients/dnn/hidden2/Selu_grad/SeluGrad" + input: "train/gradients/dnn/hidden2/add_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/Reshape_1" + op: "Reshape" + input: "train/gradients/dnn/hidden2/add_grad/Sum_1" + input: "train/gradients/dnn/hidden2/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/hidden2/add_grad/Reshape" + input: "^train/gradients/dnn/hidden2/add_grad/Reshape_1" + } + node { + name: "train/gradients/dnn/hidden2/add_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/hidden2/add_grad/Reshape" + input: "^train/gradients/dnn/hidden2/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden2/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/hidden2/add_grad/Reshape_1" + input: "^train/gradients/dnn/hidden2/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden2/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/MatMul_grad/MatMul" + op: "MatMul" + input: "train/gradients/dnn/hidden2/add_grad/tuple/control_dependency" + input: "dnn/hidden2/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } + } + node { + name: "train/gradients/dnn/hidden2/MatMul_grad/MatMul_1" + op: "MatMul" + input: "dnn/hidden1/Maximum" + input: "train/gradients/dnn/hidden2/add_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden2/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/hidden2/MatMul_grad/MatMul" + input: "^train/gradients/dnn/hidden2/MatMul_grad/MatMul_1" + } + node { + name: "train/gradients/dnn/hidden2/MatMul_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/hidden2/MatMul_grad/MatMul" + input: "^train/gradients/dnn/hidden2/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden2/MatMul_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/hidden2/MatMul_grad/MatMul_1" + input: "^train/gradients/dnn/hidden2/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden2/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Shape" + op: "Shape" + input: "dnn/hidden1/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Shape_1" + op: "Shape" + input: "dnn/hidden1/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Shape_2" + op: "Shape" + input: "train/gradients/dnn/hidden2/MatMul_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/zeros/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/zeros" + op: "Fill" + input: "train/gradients/dnn/hidden1/Maximum_grad/Shape_2" + input: "train/gradients/dnn/hidden1/Maximum_grad/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/GreaterEqual" + op: "GreaterEqual" + input: "dnn/hidden1/mul" + input: "dnn/hidden1/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "train/gradients/dnn/hidden1/Maximum_grad/Shape" + input: "train/gradients/dnn/hidden1/Maximum_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Select" + op: "Select" + input: "train/gradients/dnn/hidden1/Maximum_grad/GreaterEqual" + input: "train/gradients/dnn/hidden2/MatMul_grad/tuple/control_dependency" + input: "train/gradients/dnn/hidden1/Maximum_grad/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Select_1" + op: "Select" + input: "train/gradients/dnn/hidden1/Maximum_grad/GreaterEqual" + input: "train/gradients/dnn/hidden1/Maximum_grad/zeros" + input: "train/gradients/dnn/hidden2/MatMul_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Sum" + op: "Sum" + input: "train/gradients/dnn/hidden1/Maximum_grad/Select" + input: "train/gradients/dnn/hidden1/Maximum_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Reshape" + op: "Reshape" + input: "train/gradients/dnn/hidden1/Maximum_grad/Sum" + input: "train/gradients/dnn/hidden1/Maximum_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Sum_1" + op: "Sum" + input: "train/gradients/dnn/hidden1/Maximum_grad/Select_1" + input: "train/gradients/dnn/hidden1/Maximum_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Reshape_1" + op: "Reshape" + input: "train/gradients/dnn/hidden1/Maximum_grad/Sum_1" + input: "train/gradients/dnn/hidden1/Maximum_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/hidden1/Maximum_grad/Reshape" + input: "^train/gradients/dnn/hidden1/Maximum_grad/Reshape_1" + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/hidden1/Maximum_grad/Reshape" + input: "^train/gradients/dnn/hidden1/Maximum_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/Maximum_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/hidden1/Maximum_grad/Reshape_1" + input: "^train/gradients/dnn/hidden1/Maximum_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/Maximum_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/Shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/Shape_1" + op: "Shape" + input: "dnn/hidden1/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "train/gradients/dnn/hidden1/mul_grad/Shape" + input: "train/gradients/dnn/hidden1/mul_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/mul" + op: "Mul" + input: "train/gradients/dnn/hidden1/Maximum_grad/tuple/control_dependency" + input: "dnn/hidden1/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/Sum" + op: "Sum" + input: "train/gradients/dnn/hidden1/mul_grad/mul" + input: "train/gradients/dnn/hidden1/mul_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/Reshape" + op: "Reshape" + input: "train/gradients/dnn/hidden1/mul_grad/Sum" + input: "train/gradients/dnn/hidden1/mul_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/mul_1" + op: "Mul" + input: "dnn/hidden1/mul/x" + input: "train/gradients/dnn/hidden1/Maximum_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/Sum_1" + op: "Sum" + input: "train/gradients/dnn/hidden1/mul_grad/mul_1" + input: "train/gradients/dnn/hidden1/mul_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/Reshape_1" + op: "Reshape" + input: "train/gradients/dnn/hidden1/mul_grad/Sum_1" + input: "train/gradients/dnn/hidden1/mul_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/hidden1/mul_grad/Reshape" + input: "^train/gradients/dnn/hidden1/mul_grad/Reshape_1" + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/hidden1/mul_grad/Reshape" + input: "^train/gradients/dnn/hidden1/mul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/mul_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/hidden1/mul_grad/Reshape_1" + input: "^train/gradients/dnn/hidden1/mul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/mul_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/AddN" + op: "AddN" + input: "train/gradients/dnn/hidden1/Maximum_grad/tuple/control_dependency_1" + input: "train/gradients/dnn/hidden1/mul_grad/tuple/control_dependency_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/Maximum_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/Shape" + op: "Shape" + input: "dnn/hidden1/MatMul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/Shape_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 300 + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "train/gradients/dnn/hidden1/add_grad/Shape" + input: "train/gradients/dnn/hidden1/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/Sum" + op: "Sum" + input: "train/gradients/AddN" + input: "train/gradients/dnn/hidden1/add_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/Reshape" + op: "Reshape" + input: "train/gradients/dnn/hidden1/add_grad/Sum" + input: "train/gradients/dnn/hidden1/add_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/Sum_1" + op: "Sum" + input: "train/gradients/AddN" + input: "train/gradients/dnn/hidden1/add_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/Reshape_1" + op: "Reshape" + input: "train/gradients/dnn/hidden1/add_grad/Sum_1" + input: "train/gradients/dnn/hidden1/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/hidden1/add_grad/Reshape" + input: "^train/gradients/dnn/hidden1/add_grad/Reshape_1" + } + node { + name: "train/gradients/dnn/hidden1/add_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/hidden1/add_grad/Reshape" + input: "^train/gradients/dnn/hidden1/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/hidden1/add_grad/Reshape_1" + input: "^train/gradients/dnn/hidden1/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/MatMul_grad/MatMul" + op: "MatMul" + input: "train/gradients/dnn/hidden1/add_grad/tuple/control_dependency" + input: "dnn/hidden1/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } + } + node { + name: "train/gradients/dnn/hidden1/MatMul_grad/MatMul_1" + op: "MatMul" + input: "input" + input: "train/gradients/dnn/hidden1/add_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden1/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/hidden1/MatMul_grad/MatMul" + input: "^train/gradients/dnn/hidden1/MatMul_grad/MatMul_1" + } + node { + name: "train/gradients/dnn/hidden1/MatMul_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/hidden1/MatMul_grad/MatMul" + input: "^train/gradients/dnn/hidden1/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/MatMul_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/hidden1/MatMul_grad/MatMul_1" + input: "^train/gradients/dnn/hidden1/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/GradientDescent/learning_rate" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.009999999776482582 + } + } + } + } + node { + name: "train/GradientDescent/update_dnn/hidden1/weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "dnn/hidden1/weights" + input: "train/GradientDescent/learning_rate" + input: "train/gradients/dnn/hidden1/MatMul_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "train/GradientDescent/update_dnn/hidden1/bias/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "dnn/hidden1/bias" + input: "train/GradientDescent/learning_rate" + input: "train/gradients/dnn/hidden1/add_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "train/GradientDescent/update_dnn/hidden2/weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "dnn/hidden2/weights" + input: "train/GradientDescent/learning_rate" + input: "train/gradients/dnn/hidden2/MatMul_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "train/GradientDescent/update_dnn/hidden2/bias/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "dnn/hidden2/bias" + input: "train/GradientDescent/learning_rate" + input: "train/gradients/dnn/hidden2/add_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "train/GradientDescent/update_dnn/outputs/weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "dnn/outputs/weights" + input: "train/GradientDescent/learning_rate" + input: "train/gradients/dnn/outputs/MatMul_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "train/GradientDescent/update_dnn/outputs/bias/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "dnn/outputs/bias" + input: "train/GradientDescent/learning_rate" + input: "train/gradients/dnn/outputs/add_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "train/GradientDescent" + op: "NoOp" + input: "^train/GradientDescent/update_dnn/hidden1/weights/ApplyGradientDescent" + input: "^train/GradientDescent/update_dnn/hidden1/bias/ApplyGradientDescent" + input: "^train/GradientDescent/update_dnn/hidden2/weights/ApplyGradientDescent" + input: "^train/GradientDescent/update_dnn/hidden2/bias/ApplyGradientDescent" + input: "^train/GradientDescent/update_dnn/outputs/weights/ApplyGradientDescent" + input: "^train/GradientDescent/update_dnn/outputs/bias/ApplyGradientDescent" + } + node { + name: "eval/in_top_k/InTopKV2/k" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } + } + node { + name: "eval/in_top_k/InTopKV2" + op: "InTopKV2" + input: "dnn/outputs/add" + input: "y" + input: "eval/in_top_k/InTopKV2/k" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "eval/Cast" + op: "Cast" + input: "eval/in_top_k/InTopKV2" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_BOOL + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "eval/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "eval/Mean" + op: "Mean" + input: "eval/Cast" + input: "eval/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "init" + op: "NoOp" + input: "^dnn/hidden1/weights/Assign" + input: "^dnn/hidden1/bias/Assign" + input: "^dnn/hidden2/weights/Assign" + input: "^dnn/hidden2/bias/Assign" + input: "^dnn/outputs/weights/Assign" + input: "^dnn/outputs/bias/Assign" + } + node { + name: "Accuracy/tags" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Accuracy" + } + } + } + } + node { + name: "Accuracy" + op: "ScalarSummary" + input: "Accuracy/tags" + input: "eval/Mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "model" + } + } + } + } + node { + name: "save/StringJoin/inputs_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "_temp_de3cfc5e8e7e4734ae221577e8fd36a2/part" + } + } + } + } + node { + name: "save/StringJoin" + op: "StringJoin" + input: "save/Const" + input: "save/StringJoin/inputs_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "separator" + value { + s: "" + } + } + } + node { + name: "save/num_shards" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "save/ShardedFilename/shard" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "save/ShardedFilename" + op: "ShardedFilename" + input: "save/StringJoin" + input: "save/ShardedFilename/shard" + input: "save/num_shards" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/SaveV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 6 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 6 + } + } + string_val: "dnn/hidden1/bias" + string_val: "dnn/hidden1/weights" + string_val: "dnn/hidden2/bias" + string_val: "dnn/hidden2/weights" + string_val: "dnn/outputs/bias" + string_val: "dnn/outputs/weights" + } + } + } + } + node { + name: "save/SaveV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 6 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 6 + } + } + string_val: "" + string_val: "" + string_val: "" + string_val: "" + string_val: "" + string_val: "" + } + } + } + } + node { + name: "save/SaveV2" + op: "SaveV2" + input: "save/ShardedFilename" + input: "save/SaveV2/tensor_names" + input: "save/SaveV2/shape_and_slices" + input: "dnn/hidden1/bias" + input: "dnn/hidden1/weights" + input: "dnn/hidden2/bias" + input: "dnn/hidden2/weights" + input: "dnn/outputs/bias" + input: "dnn/outputs/weights" + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + } + } + } + } + node { + name: "save/control_dependency" + op: "Identity" + input: "save/ShardedFilename" + input: "^save/SaveV2" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@save/ShardedFilename" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/MergeV2Checkpoints/checkpoint_prefixes" + op: "Pack" + input: "save/ShardedFilename" + input: "^save/control_dependency" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "save/MergeV2Checkpoints" + op: "MergeV2Checkpoints" + input: "save/MergeV2Checkpoints/checkpoint_prefixes" + input: "save/Const" + attr { + key: "delete_old_dirs" + value { + b: true + } + } + } + node { + name: "save/Identity" + op: "Identity" + input: "save/Const" + input: "^save/control_dependency" + input: "^save/MergeV2Checkpoints" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/RestoreV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "dnn/hidden1/bias" + } + } + } + } + node { + name: "save/RestoreV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2/tensor_names" + input: "save/RestoreV2/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign" + op: "Assign" + input: "dnn/hidden1/bias" + input: "save/RestoreV2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_1/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "dnn/hidden1/weights" + } + } + } + } + node { + name: "save/RestoreV2_1/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_1" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_1/tensor_names" + input: "save/RestoreV2_1/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_1" + op: "Assign" + input: "dnn/hidden1/weights" + input: "save/RestoreV2_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "dnn/hidden2/bias" + } + } + } + } + node { + name: "save/RestoreV2_2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_2" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_2/tensor_names" + input: "save/RestoreV2_2/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_2" + op: "Assign" + input: "dnn/hidden2/bias" + input: "save/RestoreV2_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_3/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "dnn/hidden2/weights" + } + } + } + } + node { + name: "save/RestoreV2_3/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_3" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_3/tensor_names" + input: "save/RestoreV2_3/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_3" + op: "Assign" + input: "dnn/hidden2/weights" + input: "save/RestoreV2_3" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_4/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "dnn/outputs/bias" + } + } + } + } + node { + name: "save/RestoreV2_4/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_4" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_4/tensor_names" + input: "save/RestoreV2_4/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_4" + op: "Assign" + input: "dnn/outputs/bias" + input: "save/RestoreV2_4" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_5/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "dnn/outputs/weights" + } + } + } + } + node { + name: "save/RestoreV2_5/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_5" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_5/tensor_names" + input: "save/RestoreV2_5/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_5" + op: "Assign" + input: "dnn/outputs/weights" + input: "save/RestoreV2_5" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/restore_shard" + op: "NoOp" + input: "^save/Assign" + input: "^save/Assign_1" + input: "^save/Assign_2" + input: "^save/Assign_3" + input: "^save/Assign_4" + input: "^save/Assign_5" + } + node { + name: "save/restore_all" + op: "NoOp" + input: "^save/restore_shard" + } + versions { + producer: 24 + } + } + saver_def { + filename_tensor_name: "save/Const:0" + save_tensor_name: "save/Identity:0" + restore_op_name: "save/restore_all" + max_to_keep: 5 + sharded: true + keep_checkpoint_every_n_hours: 10000.0 + version: V2 + } + collection_def { + key: "summaries" + value { + node_list { + value: "Accuracy:0" + } + } + } + collection_def { + key: "train_op" + value { + node_list { + value: "train/GradientDescent" + } + } + } + collection_def { + key: "trainable_variables" + value { + bytes_list { + value: "\n\025dnn/hidden1/weights:0\022\032dnn/hidden1/weights/Assign\032\032dnn/hidden1/weights/read:02\036dnn/hidden1/truncated_normal:0" + value: "\n\022dnn/hidden1/bias:0\022\027dnn/hidden1/bias/Assign\032\027dnn/hidden1/bias/read:02\023dnn/hidden1/zeros:0" + value: "\n\025dnn/hidden2/weights:0\022\032dnn/hidden2/weights/Assign\032\032dnn/hidden2/weights/read:02\036dnn/hidden2/truncated_normal:0" + value: "\n\022dnn/hidden2/bias:0\022\027dnn/hidden2/bias/Assign\032\027dnn/hidden2/bias/read:02\023dnn/hidden2/zeros:0" + value: "\n\025dnn/outputs/weights:0\022\032dnn/outputs/weights/Assign\032\032dnn/outputs/weights/read:02\036dnn/outputs/truncated_normal:0" + value: "\n\022dnn/outputs/bias:0\022\027dnn/outputs/bias/Assign\032\027dnn/outputs/bias/read:02\023dnn/outputs/zeros:0" + } + } + } + collection_def { + key: "variables" + value { + bytes_list { + value: "\n\025dnn/hidden1/weights:0\022\032dnn/hidden1/weights/Assign\032\032dnn/hidden1/weights/read:02\036dnn/hidden1/truncated_normal:0" + value: "\n\022dnn/hidden1/bias:0\022\027dnn/hidden1/bias/Assign\032\027dnn/hidden1/bias/read:02\023dnn/hidden1/zeros:0" + value: "\n\025dnn/hidden2/weights:0\022\032dnn/hidden2/weights/Assign\032\032dnn/hidden2/weights/read:02\036dnn/hidden2/truncated_normal:0" + value: "\n\022dnn/hidden2/bias:0\022\027dnn/hidden2/bias/Assign\032\027dnn/hidden2/bias/read:02\023dnn/hidden2/zeros:0" + value: "\n\025dnn/outputs/weights:0\022\032dnn/outputs/weights/Assign\032\032dnn/outputs/weights/read:02\036dnn/outputs/truncated_normal:0" + value: "\n\022dnn/outputs/bias:0\022\027dnn/outputs/bias/Assign\032\027dnn/outputs/bias/read:02\023dnn/outputs/zeros:0" + } + } + } + signature_def { + key: "serving_default" + value { + inputs { + key: "x" + value { + name: "input:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + outputs { + key: "y" + value { + name: "dnn/outputs/add:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + method_name: "tensorflow/serving/predict" + } + } +} diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 b/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 Binary files differnew file mode 100644 index 00000000000..a7ca01888c7 --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index b/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index Binary files differnew file mode 100644 index 00000000000..7989c109a3a --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py b/searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py new file mode 100644 index 00000000000..26529f67919 --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py @@ -0,0 +1,97 @@ + +# Common imports +import numpy as np +import tensorflow as tf + +from tensorflow.examples.tutorials.mnist import input_data +from datetime import datetime + +now = datetime.utcnow().strftime("%Y%m%d%H%M%S") +root_logdir = "tf_logs" +logdir = "{}/run-{}/".format(root_logdir, now) + +mnist = input_data.read_data_sets("/tmp/data/") +X_train = mnist.train.images +X_test = mnist.test.images +y_train = mnist.train.labels.astype("int") +y_test = mnist.test.labels.astype("int") + +n_inputs = 28*28 # MNIST +n_hidden1 = 300 +n_hidden2 = 100 +n_hidden3 = 40 +n_outputs = 10 + +learning_rate = 0.01 +n_epochs = 20 +batch_size = 50 + +input = tf.placeholder(tf.float32, shape=(None, n_inputs), name="input") +y = tf.placeholder(tf.int64, shape=(None), name="y") + + +def neuron_layer(X, n_neurons, name, activation=None): + with tf.name_scope(name): + n_inputs = int(X.get_shape()[1]) + stddev = 2 / np.sqrt(n_inputs) + init = tf.truncated_normal((n_inputs, n_neurons), stddev=stddev) + W = tf.Variable(init, name="weights") + b = tf.Variable(tf.zeros([n_neurons]), name="bias") + Z = tf.matmul(X, W) + b + if activation is not None: + return activation(Z) + else: + return Z + + +def leaky_relu(z, name=None): + return tf.maximum(0.01 * z, z, name=name) + + +with tf.name_scope("dnn"): + hidden1 = neuron_layer(input, n_hidden1, name="hidden1", activation=leaky_relu) + hidden2 = neuron_layer(hidden1, n_hidden2, name="hidden2", activation=tf.nn.selu) + logits = neuron_layer(hidden2, n_outputs, name="outputs") #, activation=tf.nn.sigmoid) + +with tf.name_scope("loss"): + xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits) + loss = tf.reduce_mean(xentropy, name="loss") + +with tf.name_scope("train"): + optimizer = tf.train.GradientDescentOptimizer(learning_rate) + training_op = optimizer.minimize(loss) + +with tf.name_scope("eval"): + correct = tf.nn.in_top_k(logits, y, 1) + accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) + +init = tf.global_variables_initializer() +accuracy_summary = tf.summary.scalar('Accuracy', accuracy) +file_writer = tf.summary.FileWriter(logdir, tf.get_default_graph()) + +with tf.Session() as sess: + init.run() + for epoch in range(n_epochs): + for iteration in range(mnist.train.num_examples // batch_size): + X_batch, y_batch = mnist.train.next_batch(batch_size) + sess.run(training_op, feed_dict={input: X_batch, y: y_batch}) + acc_train = accuracy.eval(feed_dict={input: X_batch, y: y_batch}) + acc_val = accuracy.eval(feed_dict={input: mnist.validation.images, + y: mnist.validation.labels}) + print(epoch, "Train accuracy:", acc_train, "Val accuracy:", acc_val) + + # Save summary for tensorboard + summary_str = accuracy_summary.eval(feed_dict={input: mnist.validation.images, + y: mnist.validation.labels}) + file_writer.add_summary(summary_str, epoch) + + export_path = "saved" + print('Exporting trained model to ', export_path) + builder = tf.saved_model.builder.SavedModelBuilder(export_path) + signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs = {'x':input}, outputs = {'y':logits}) + builder.add_meta_graph_and_variables(sess, + [tf.saved_model.tag_constants.SERVING], + signature_def_map={'serving_default':signature}) + builder.save(as_text=True) + +file_writer.close()
\ No newline at end of file diff --git a/searchlib/src/test/java/com/yahoo/searchlib/expression/ExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/expression/ExpressionTestCase.java index ad50d3cc3d4..1fa012c83f7 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/expression/ExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/expression/ExpressionTestCase.java @@ -215,6 +215,37 @@ public class ExpressionTestCase { } @Test + public void testAttributeMapLookupNode() { + assertEquals(AttributeMapLookupNode.fromKey("map{\"my_key\"}", "map.key", "map.value", "my_key"), + AttributeMapLookupNode.fromKey("map{\"my_key\"}", "map.key", "map.value", "my_key")); + assertNotEquals(AttributeMapLookupNode.fromKey("map{\"my_key\"}", "map.key", "map.value", "my_key"), + AttributeMapLookupNode.fromKey("null", "map.key", "map.value", "my_key")); + assertNotEquals(AttributeMapLookupNode.fromKey("map{\"my_key\"}", "map.key", "map.value", "my_key"), + AttributeMapLookupNode.fromKey("map{\"my_key\"}", "null", "map.value", "my_key")); + assertNotEquals(AttributeMapLookupNode.fromKey("map{\"my_key\"}", "map.key", "map.value", "my_key"), + AttributeMapLookupNode.fromKey("map{\"my_key\"}", "map.key", "null", "my_key")); + assertNotEquals(AttributeMapLookupNode.fromKey("map{\"my_key\"}", "map.key", "map.value", "my_key"), + AttributeMapLookupNode.fromKey("map{\"my_key\"}", "map.key", "map.value", "null")); + + assertEquals(AttributeMapLookupNode.fromKeySourceAttribute("map{attribute(key_source)}", "map.key", "map.value", "key_source"), + AttributeMapLookupNode.fromKeySourceAttribute("map{attribute(key_source)}", "map.key", "map.value", "key_source")); + assertNotEquals(AttributeMapLookupNode.fromKeySourceAttribute("map{attribute(key_source)}", "map.key", "map.value", "key_source"), + AttributeMapLookupNode.fromKeySourceAttribute("map{attribute(key_source)}", "map.key", "map.value", "null")); + + assertAttributeMapLookupNodeSerialize( + AttributeMapLookupNode.fromKey("map{\"my_key\"}", "map.key", "map.value", "my_key")); + assertAttributeMapLookupNodeSerialize( + AttributeMapLookupNode.fromKeySourceAttribute("map{attribute(key_source)}", "map.key", "map.value", "key_source")); + } + + private static void assertAttributeMapLookupNodeSerialize(AttributeMapLookupNode a) { + AttributeMapLookupNode b = (AttributeMapLookupNode)assertSerialize(a); + assertEquals(a, b); + AttributeMapLookupNode c = (AttributeMapLookupNode)assertSerialize(b); + assertEquals(a, c); + } + + @Test public void testInterpolatedLookupNode() { ExpressionNode argA = new ConstantNode(new FloatResultNode(2.71828182846)); ExpressionNode argB = new ConstantNode(new FloatResultNode(3.14159265359)); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java new file mode 100644 index 00000000000..add66eece1a --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java @@ -0,0 +1,30 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.ml; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * @author bratseth + */ +public class MnistImportTestCase { + + @Test + public void testMnistImport() { + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist/saved"); + ImportedModel.Signature signature = model.get().signature("serving_default"); + + assertEquals("Has skipped outputs", + 0, model.get().signature("serving_default").skippedOutputs().size()); + + RankingExpression output = signature.outputExpression("y"); + assertNotNull(output); + assertEquals("dnn/outputs/add", output.getName()); + model.assertEqualResultSum("input", output.getName(), 0.00001); + } + + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java index a7926cd2e02..bcfc6ce0a04 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java @@ -7,9 +7,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; -import org.tensorflow.SavedModelBundle; - -import java.io.IOException; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -21,7 +18,7 @@ import static org.junit.Assert.assertTrue; public class OnnxMnistSoftmaxImportTestCase { @Test - public void testMnistSoftmaxImport() throws IOException { + public void testMnistSoftmaxImport() { ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); // Check constants @@ -43,14 +40,14 @@ public class OnnxMnistSoftmaxImportTestCase { assertEquals(1, model.requiredMacros().size()); assertTrue(model.requiredMacros().containsKey("Placeholder")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.requiredMacros().get("Placeholder")); + model.requiredMacros().get("Placeholder")); // Check outputs RankingExpression output = model.defaultSignature().outputExpression("add"); assertNotNull(output); assertEquals("add", output.getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", - output.getRoot().toString()); + output.getRoot().toString()); } @Test diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java index bd7644be23b..dd6c8095e3c 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java @@ -13,7 +13,7 @@ import static org.junit.Assert.assertTrue; /** * @author bratseth */ -public class MnistSoftmaxImportTestCase { +public class TensorFlowMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java index 723c5f27914..4de3aa5d635 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java @@ -36,11 +36,26 @@ public class TestableTensorFlowModel { public TestableTensorFlowModel(String modelName, String modelDir) { tensorFlowModel = SavedModelBundle.load(modelDir, "serve"); - model = new TensorFlowImporter().importModel(modelName, tensorFlowModel); + model = new TensorFlowImporter().importModel(modelName, modelDir, tensorFlowModel); } public ImportedModel get() { return model; } + /** Compare that summing the tensors produce the same result to within some tolerance delta */ + public void assertEqualResultSum(String inputName, String operationName, double delta) { + Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName); + Context context = contextFrom(model); + Tensor placeholder = placeholderArgument(); + context.put(inputName, new TensorValue(placeholder)); + + model.macros().forEach((k,v) -> evaluateMacro(context, model, k)); + + Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); + assertEquals("Operation '" + operationName + "' produces equal results", + tfResult.sum().asDouble(), vespaResult.sum().asDouble(), delta); + } + + /** Compare tensors 100% exactly */ public void assertEqualResult(String inputName, String operationName) { Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName); Context context = contextFrom(model); diff --git a/searchlib/src/tests/expression/attributenode/CMakeLists.txt b/searchlib/src/tests/expression/attributenode/CMakeLists.txt new file mode 100644 index 00000000000..3006c27dd0d --- /dev/null +++ b/searchlib/src/tests/expression/attributenode/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(searchlib_attribute_node_test_app TEST + SOURCES + attribute_node_test.cpp + DEPENDS + searchlib + searchlib_test +) +vespa_add_test(NAME searchlib_attribute_node_test_app COMMAND searchlib_attribute_node_test_app) diff --git a/searchlib/src/tests/expression/attributenode/attribute_node_test.cpp b/searchlib/src/tests/expression/attributenode/attribute_node_test.cpp new file mode 100644 index 00000000000..7490b0699be --- /dev/null +++ b/searchlib/src/tests/expression/attributenode/attribute_node_test.cpp @@ -0,0 +1,429 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/searchcommon/common/undefinedvalues.h> +#include <vespa/searchlib/attribute/attributefactory.h> +#include <vespa/searchlib/attribute/attributecontext.h> +#include <vespa/searchlib/attribute/attributemanager.h> +#include <vespa/searchlib/attribute/attributevector.h> +#include <vespa/searchlib/attribute/attributevector.hpp> +#include <vespa/searchlib/attribute/floatbase.h> +#include <vespa/searchlib/attribute/integerbase.h> +#include <vespa/searchlib/attribute/stringbase.h> +#include <vespa/searchlib/expression/attributenode.h> +#include <vespa/searchlib/expression/resultvector.h> +#include <vespa/searchlib/test/make_attribute_map_lookup_node.h> +#include <vespa/vespalib/test/insertion_operators.h> +#include <vespa/vespalib/testkit/testapp.h> + +#include <vespa/log/log.h> +LOG_SETUP("attribute_node_test"); + +using search::AttributeContext; +using search::AttributeFactory; +using search::AttributeManager; +using search::AttributeVector; +using search::IntegerAttribute; +using search::FloatingPointAttribute; +using search::StringAttribute; +using search::attribute::BasicType; +using search::attribute::CollectionType; +using search::attribute::Config; +using search::attribute::IAttributeVector; +using search::attribute::getUndefined; +using search::expression::AttributeNode; +using search::expression::EnumResultNode; +using search::expression::EnumResultNodeVector; +using search::expression::FloatResultNode; +using search::expression::FloatResultNodeVector; +using search::expression::Int8ResultNode; +using search::expression::Int8ResultNodeVector; +using search::expression::IntegerResultNodeVector; +using search::expression::IntegerResultNode; +using search::expression::ResultNode; +using search::expression::ResultNodeVector; +using search::expression::StringResultNode; +using search::expression::StringResultNodeVector; +using search::expression::test::makeAttributeMapLookupNode; +using vespalib::BufferRef; + +namespace { + +vespalib::string stringValue(const ResultNode &result, const IAttributeVector &attr) { + if (result.inherits(EnumResultNode::classId)) { + auto enumHandle = result.getEnum(); + auto &stringAttr = dynamic_cast<const StringAttribute &>(attr); + return vespalib::string(stringAttr.getFromEnum(enumHandle)); + } + char buf[100]; + BufferRef bref(&buf[0], sizeof(buf)); + auto sbuf = result.getString(bref); + return vespalib::string(sbuf.c_str(), sbuf.c_str() + sbuf.size()); +} + +struct AttributeManagerFixture +{ + AttributeManager mgr; + + AttributeManagerFixture(); + ~AttributeManagerFixture(); + template <typename AttributeType, typename ValueType> + void buildAttribute(const vespalib::string &name, BasicType type, std::vector<ValueType> values); + void buildStringAttribute(const vespalib::string &name, std::vector<vespalib::string> values); + void buildFloatAttribute(const vespalib::string &name, std::vector<double> values); + void buildIntegerAttribute(const vespalib::string &name, BasicType type, std::vector<IAttributeVector::largeint_t> values); + template <typename AttributeType, typename ValueType> + void buildArrayAttribute(const vespalib::string &name, BasicType type, std::vector<std::vector<ValueType>> values); + void buildStringArrayAttribute(const vespalib::string &name,std::vector<std::vector<vespalib::string>> values); + void buildFloatArrayAttribute(const vespalib::string &name, std::vector<std::vector<double>> values); + void buildIntegerArrayAttribute(const vespalib::string &name, BasicType type, std::vector<std::vector<IAttributeVector::largeint_t>> values); +}; + +AttributeManagerFixture::AttributeManagerFixture() + : mgr() +{ + buildStringAttribute("sfield", { "n1", ""}); + buildIntegerAttribute("ifield", BasicType::Type::INT8, { 10, getUndefined<int8_t>() }); + buildFloatAttribute("ffield", { 110.0, getUndefined<double>() }); + buildStringArrayAttribute("array.name", {{"n1.1", "n1.2"}, {"n2"}, {}}); + buildIntegerArrayAttribute("array.val", BasicType::Type::INT8, {{ 10, 11}, {20, 21 }, {}}); + buildFloatArrayAttribute("array.fval", {{ 110.0}, { 120.0, 121.0 }, {}}); + buildStringArrayAttribute("smap.key", {{"k1.1", "k1.2"}, {"k2"}, {}}); + buildStringArrayAttribute("smap.value.name", {{"n1.1", "n1.2"}, {"n2"}, {}}); + buildIntegerArrayAttribute("smap.value.val", BasicType::Type::INT8, {{ 10, 11}, {20, 21 }, {}}); + buildFloatArrayAttribute("smap.value.fval", {{ 110.0}, { 120.0, 121.0 }, {}}); + buildStringArrayAttribute("map.key", {{"k1.1", "k1.2"}, {"k2"}, {}}); + buildStringArrayAttribute("map.value", {{"n1.1", "n1.2"}, {"n2"}, {}}); + buildStringAttribute("keyfield1", {"k1.2", "k2", "k3"}); + buildStringAttribute("keyfield2", {"k1.1", "k1", "k1"}); +} + +AttributeManagerFixture::~AttributeManagerFixture() = default; + +template <typename AttributeType, typename ValueType> +void +AttributeManagerFixture::buildAttribute(const vespalib::string &name, + BasicType type, + std::vector<ValueType> values) +{ + Config cfg(type, CollectionType::Type::SINGLE); + auto attrBase = AttributeFactory::createAttribute(name, cfg); + EXPECT_TRUE(attrBase); + auto attr = std::dynamic_pointer_cast<AttributeType>(attrBase); + EXPECT_TRUE(attr); + attr->addReservedDoc(); + for (const auto &value : values) { + uint32_t docId = 0; + EXPECT_TRUE(attr->addDoc(docId)); + EXPECT_NOT_EQUAL(0u, docId); + attr->update(docId, value); + attr->commit(); + } + EXPECT_TRUE(mgr.add(attr)); +} + +void +AttributeManagerFixture::buildStringAttribute(const vespalib::string &name, + std::vector<vespalib::string> values) +{ + buildAttribute<StringAttribute, vespalib::string>(name, BasicType::Type::STRING, std::move(values)); +} + +void +AttributeManagerFixture::buildFloatAttribute(const vespalib::string &name, + std::vector<double> values) +{ + buildAttribute<FloatingPointAttribute, double>(name, BasicType::Type::DOUBLE, std::move(values)); +} + +void +AttributeManagerFixture::buildIntegerAttribute(const vespalib::string &name, + BasicType type, + std::vector<IAttributeVector::largeint_t> values) +{ + buildAttribute<IntegerAttribute, IAttributeVector::largeint_t>(name, type, std::move(values)); +} + +template <typename AttributeType, typename ValueType> +void +AttributeManagerFixture::buildArrayAttribute(const vespalib::string &name, + BasicType type, + std::vector<std::vector<ValueType>> values) +{ + Config cfg(type, CollectionType::Type::ARRAY); + auto attrBase = AttributeFactory::createAttribute(name, cfg); + EXPECT_TRUE(attrBase); + auto attr = std::dynamic_pointer_cast<AttributeType>(attrBase); + EXPECT_TRUE(attr); + attr->addReservedDoc(); + for (const auto &docValues : values) { + uint32_t docId = 0; + EXPECT_TRUE(attr->addDoc(docId)); + EXPECT_NOT_EQUAL(0u, docId); + for (const auto &value : docValues) { + attr->append(docId, value, 1); + } + attr->commit(); + } + EXPECT_TRUE(mgr.add(attr)); +} + +void +AttributeManagerFixture::buildStringArrayAttribute(const vespalib::string &name, + std::vector<std::vector<vespalib::string>> values) +{ + buildArrayAttribute<StringAttribute, vespalib::string>(name, BasicType::Type::STRING, std::move(values)); +} + +void +AttributeManagerFixture::buildFloatArrayAttribute(const vespalib::string &name, + std::vector<std::vector<double>> values) +{ + buildArrayAttribute<FloatingPointAttribute, double>(name, BasicType::Type::DOUBLE, std::move(values)); +} + +void +AttributeManagerFixture::buildIntegerArrayAttribute(const vespalib::string &name, + BasicType type, + std::vector<std::vector<IAttributeVector::largeint_t>> values) +{ + buildArrayAttribute<IntegerAttribute, IAttributeVector::largeint_t>(name, type, std::move(values)); +} + + +struct Fixture +{ + AttributeManagerFixture attrs; + AttributeContext context; + Fixture(); + ~Fixture(); + std::unique_ptr<AttributeNode> makeNode(const vespalib::string &attributeName, bool useEnumOptimiation = false, bool preserveAccurateTypes = false); + void assertInts(std::vector<IAttributeVector::largeint_t> expVals, const vespalib::string &attributteName, bool preserveAccurateTypes = false); + void assertStrings(std::vector<vespalib::string> expVals, const vespalib::string &attributteName, bool useEnumOptimization = false); + void assertFloats(std::vector<double> expVals, const vespalib::string &attributteName); + void assertIntArrays(std::vector<std::vector<IAttributeVector::largeint_t>> expVals, const vespalib::string &attributteName, bool preserveAccurateTypes = false); + void assertStringArrays(std::vector<std::vector<vespalib::string>> expVals, const vespalib::string &attributteName, bool useEnumOptimization = false); + void assertFloatArrays(std::vector<std::vector<double>> expVals, const vespalib::string &attributteName); +}; + +Fixture::Fixture() + : attrs(), + context(attrs.mgr) +{ +} + +Fixture::~Fixture() = default; + +std::unique_ptr<AttributeNode> +Fixture::makeNode(const vespalib::string &attributeName, bool useEnumOptimization, bool preserveAccurateTypes) +{ + std::unique_ptr<AttributeNode> node; + if (attributeName.find('{') == vespalib::string::npos) { + node = std::make_unique<AttributeNode>(attributeName); + } else { + node = makeAttributeMapLookupNode(attributeName); + } + if (useEnumOptimization) { + node->useEnumOptimization(); + } + AttributeNode::Configure configure(context); + node->select(configure, configure); + node->prepare(preserveAccurateTypes); + return node; +} + + +void +Fixture::assertInts(std::vector<IAttributeVector::largeint_t> expVals, const vespalib::string &attributeName, bool preserveAccurateTypes) +{ + auto node = makeNode(attributeName, false, preserveAccurateTypes); + uint32_t docId = 0; + for (const auto &expDocVal : expVals) { + ++docId; + node->setDocId(docId); + node->execute(); + const auto &result = node->getResult(); + if (preserveAccurateTypes) { + ASSERT_TRUE(result.inherits(Int8ResultNode::classId)); + } else { + ASSERT_TRUE(result.inherits(IntegerResultNode::classId)); + } + IAttributeVector::largeint_t docVal = result.getInteger(); + EXPECT_EQUAL(expDocVal, docVal); + } +} + +void +Fixture::assertStrings(std::vector<vespalib::string> expVals, const vespalib::string &attributeName, bool useEnumOptimization) +{ + auto node = makeNode(attributeName, useEnumOptimization); + uint32_t docId = 0; + for (const auto &expDocVal : expVals) { + ++docId; + node->setDocId(docId); + node->execute(); + const auto &result = node->getResult(); + if (useEnumOptimization) { + ASSERT_TRUE(result.inherits(EnumResultNode::classId)); + } else { + ASSERT_TRUE(result.inherits(StringResultNode::classId)); + } + vespalib::string docVal = stringValue(result, *node->getAttribute()); + EXPECT_EQUAL(expDocVal, docVal); + } +} + +void +Fixture::assertFloats(std::vector<double> expVals, const vespalib::string &attributeName) +{ + auto node = makeNode(attributeName); + uint32_t docId = 0; + for (const auto &expDocVal : expVals) { + ++docId; + node->setDocId(docId); + node->execute(); + const auto &result = node->getResult(); + ASSERT_TRUE(result.inherits(FloatResultNode::classId)); + double docVal = result.getFloat(); + EXPECT_EQUAL(std::isnan(expDocVal), std::isnan(docVal)); + if (!std::isnan(expDocVal)) { + EXPECT_EQUAL(expDocVal, docVal); + } + } +} + +void +Fixture::assertIntArrays(std::vector<std::vector<IAttributeVector::largeint_t>> expVals, const vespalib::string &attributeName, bool preserveAccurateTypes) +{ + auto node = makeNode(attributeName, false, preserveAccurateTypes); + uint32_t docId = 0; + for (const auto &expDocVals : expVals) { + ++docId; + node->setDocId(docId); + node->execute(); + const auto &result = node->getResult(); + ASSERT_TRUE(result.inherits(ResultNodeVector::classId)); + const auto &resultVector = static_cast<const ResultNodeVector &>(result); + if (preserveAccurateTypes) { + ASSERT_TRUE(result.inherits(Int8ResultNodeVector::classId)); + } else { + ASSERT_TRUE(result.inherits(IntegerResultNodeVector::classId)); + } + std::vector<IAttributeVector::largeint_t> docVals; + for (size_t i = 0; i < resultVector.size(); ++i) { + docVals.push_back(resultVector.get(i).getInteger()); + } + EXPECT_EQUAL(expDocVals, docVals); + } +} + +void +Fixture::assertStringArrays(std::vector<std::vector<vespalib::string>> expVals, const vespalib::string &attributeName, bool useEnumOptimization) +{ + auto node = makeNode(attributeName, useEnumOptimization); + uint32_t docId = 0; + for (const auto &expDocVals : expVals) { + ++docId; + node->setDocId(docId); + node->execute(); + const auto &result = node->getResult(); + ASSERT_TRUE(result.inherits(ResultNodeVector::classId)); + const auto &resultVector = static_cast<const ResultNodeVector &>(result); + if (useEnumOptimization) { + ASSERT_TRUE(result.inherits(EnumResultNodeVector::classId)); + } else { + ASSERT_TRUE(result.inherits(StringResultNodeVector::classId)); + } + std::vector<vespalib::string> docVals; + for (size_t i = 0; i < resultVector.size(); ++i) { + docVals.push_back(stringValue(resultVector.get(i), *node->getAttribute())); + } + EXPECT_EQUAL(expDocVals, docVals); + } +} + +void +Fixture::assertFloatArrays(std::vector<std::vector<double>> expVals, const vespalib::string &attributeName) +{ + auto node = makeNode(attributeName); + uint32_t docId = 0; + for (const auto &expDocVals : expVals) { + ++docId; + node->setDocId(docId); + node->execute(); + const auto &result = node->getResult(); + ASSERT_TRUE(result.inherits(ResultNodeVector::classId)); + const auto &resultVector = static_cast<const ResultNodeVector &>(result); + ASSERT_TRUE(result.inherits(FloatResultNodeVector::classId)); + std::vector<double> docVals; + for (size_t i = 0; i < resultVector.size(); ++i) { + docVals.push_back(resultVector.get(i).getFloat()); + } + EXPECT_EQUAL(expDocVals.size(), docVals.size()); + for (size_t i = 0; i < expDocVals.size(); ++i) { + EXPECT_EQUAL(std::isnan(expDocVals[i]), std::isnan(docVals[i])); + if (!std::isnan(expDocVals[i])) { + EXPECT_EQUAL(expDocVals[i], docVals[i]); + } + } + } +} + +TEST_F("test single values", Fixture) +{ + TEST_DO(f.assertInts({ 10, getUndefined<int8_t>()}, "ifield")); + TEST_DO(f.assertInts({ 10, getUndefined<int8_t>()}, "ifield", true)); + TEST_DO(f.assertStrings({ "n1", "" }, "sfield")); + TEST_DO(f.assertStrings({ "n1", "" }, "sfield", true)); + TEST_DO(f.assertFloats({ 110.0, getUndefined<double>() }, "ffield")); +} + +TEST_F("Test array values", Fixture) +{ + TEST_DO(f.assertIntArrays({{ 10, 11}, {20, 21 }, {}}, "array.val")); + TEST_DO(f.assertIntArrays({{ 10, 11}, {20, 21 }, {}}, "array.val", true)); + TEST_DO(f.assertStringArrays({{"n1.1", "n1.2"}, {"n2"}, {}}, "array.name")); + TEST_DO(f.assertStringArrays({{"n1.1", "n1.2"}, {"n2"}, {}}, "array.name", true)); + TEST_DO(f.assertFloatArrays({{ 110.0}, { 120.0, 121.0 }, {}}, "array.fval")); + TEST_DO(f.assertStringArrays({{"k1.1", "k1.2"}, {"k2"}, {}}, "smap.key")); + TEST_DO(f.assertStringArrays({{"n1.1", "n1.2"}, {"n2"}, {}}, "smap.value.name")); + TEST_DO(f.assertIntArrays({{ 10, 11}, {20, 21 }, {}}, "smap.value.val")); + TEST_DO(f.assertFloatArrays({{ 110.0}, { 120.0, 121.0 }, {}}, "smap.value.fval")); + TEST_DO(f.assertStringArrays({{"k1.1", "k1.2"}, {"k2"}, {}}, "map.key")); + TEST_DO(f.assertStringArrays({{"n1.1", "n1.2"}, {"n2"}, {}}, "map.value")); +} + +TEST_F("test keyed values", Fixture) +{ + TEST_DO(f.assertStrings({"n1.1", "", ""}, "smap{\"k1.1\"}.name")); + TEST_DO(f.assertStrings({"n1.2", "", ""}, "smap{\"k1.2\"}.name")); + TEST_DO(f.assertStrings({"", "n2", ""}, "smap{\"k2\"}.name")); + TEST_DO(f.assertStrings({"", "", ""}, "smap{\"k5\"}.name")); + TEST_DO(f.assertFloats({ 110.0, getUndefined<double>(), getUndefined<double>()}, "smap{\"k1.1\"}.fval")); + TEST_DO(f.assertFloats({ getUndefined<double>(), getUndefined<double>(), getUndefined<double>()}, "smap{\"k1.2\"}.fval")); + TEST_DO(f.assertFloats({ getUndefined<double>(), 120.0, getUndefined<double>()}, "smap{\"k2\"}.fval")); + TEST_DO(f.assertFloats({ getUndefined<double>(), getUndefined<double>(), getUndefined<double>()}, "smap{\"k5\"}.fval")); + TEST_DO(f.assertInts({ 10, getUndefined<int8_t>(), getUndefined<int8_t>()}, "smap{\"k1.1\"}.val")); + TEST_DO(f.assertInts({ 11, getUndefined<int8_t>(), getUndefined<int8_t>()}, "smap{\"k1.2\"}.val")); + TEST_DO(f.assertInts({ getUndefined<int8_t>(), 20, getUndefined<int8_t>()}, "smap{\"k2\"}.val")); + TEST_DO(f.assertInts({ getUndefined<int8_t>(), getUndefined<int8_t>(), getUndefined<int8_t>()}, "smap{\"k5\"}.val")); + TEST_DO(f.assertStrings({"n1.1", "", ""}, "map{\"k1.1\"}")); + TEST_DO(f.assertStrings({"n1.2", "", ""}, "map{\"k1.2\"}")); + TEST_DO(f.assertStrings({"", "n2", ""}, "map{\"k2\"}")); + TEST_DO(f.assertStrings({"", "", ""}, "map{\"k5\"}")); +} + +TEST_F("test indirectly keyed values", Fixture) +{ + TEST_DO(f.assertStrings({"n1.2", "n2", ""}, "map{attribute(keyfield1)}")); + TEST_DO(f.assertStrings({"n1.1", "", ""}, "map{attribute(keyfield2)}")); + TEST_DO(f.assertStrings({"n1.2", "n2", ""}, "smap{attribute(keyfield1)}.name")); + TEST_DO(f.assertStrings({"n1.1", "", ""}, "smap{attribute(keyfield2)}.name")); + TEST_DO(f.assertFloats({ getUndefined<double>(), 120.0, getUndefined<double>()}, "smap{attribute(keyfield1)}.fval")); + TEST_DO(f.assertFloats({ 110.0, getUndefined<double>(), getUndefined<double>()}, "smap{attribute(keyfield2)}.fval")); + TEST_DO(f.assertInts({ 11, 20, getUndefined<int8_t>()}, "smap{attribute(keyfield1)}.val")); + TEST_DO(f.assertInts({ 10, getUndefined<int8_t>(), getUndefined<int8_t>()}, "smap{attribute(keyfield2)}.val")); +} + +} + +TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchlib/src/tests/grouping/CMakeLists.txt b/searchlib/src/tests/grouping/CMakeLists.txt index ea90f27d9f6..e2bb73c57a4 100644 --- a/searchlib/src/tests/grouping/CMakeLists.txt +++ b/searchlib/src/tests/grouping/CMakeLists.txt @@ -4,6 +4,7 @@ vespa_add_executable(searchlib_grouping_test_app TEST grouping_test.cpp DEPENDS searchlib + searchlib_test ) vespa_add_test(NAME searchlib_grouping_test_app COMMAND searchlib_grouping_test_app) vespa_add_executable(searchlib_hyperloglog_test_app TEST diff --git a/searchlib/src/tests/grouping/grouping_test.cpp b/searchlib/src/tests/grouping/grouping_test.cpp index 4cf9eb6f5c6..084f13795f7 100644 --- a/searchlib/src/tests/grouping/grouping_test.cpp +++ b/searchlib/src/tests/grouping/grouping_test.cpp @@ -9,6 +9,8 @@ #include <vespa/searchlib/aggregation/fs4hit.h> #include <vespa/searchlib/aggregation/predicates.h> #include <vespa/searchlib/expression/fixedwidthbucketfunctionnode.h> +#include <vespa/searchlib/test/make_attribute_map_lookup_node.h> +#include <vespa/searchcommon/common/undefinedvalues.h> #include <algorithm> #include <cmath> #include <iostream> @@ -21,6 +23,13 @@ using namespace search; using namespace search::aggregation; using namespace search::attribute; using namespace search::expression; +using search::expression::test::makeAttributeMapLookupNode; + +namespace { + +const int64_t undefinedInteger = getUndefined<int64_t>(); + +} //----------------------------------------------------------------------------- @@ -61,6 +70,14 @@ public: _attr->add(value); return *this; } + AttrBuilder &add(std::vector<T> values) { + DocId ignore; + _attr->addDoc(ignore); + for (T value : values) { + _attr->add(value); + } + return *this; + } AttributeVector::SP sp() const { return _attrSP; } @@ -70,6 +87,9 @@ typedef AttrBuilder<SingleIntegerExtAttribute, int64_t> IntAttrBuilder; typedef AttrBuilder<SingleFloatExtAttribute, double> FloatAttrBuilder; typedef AttrBuilder<SingleStringExtAttribute, const char *> StringAttrBuilder; +using StringArrayAttrBuilder = AttrBuilder<MultiStringExtAttribute, const char *>; +using IntArrayAttrBuilder = AttrBuilder<MultiIntegerExtAttribute, int64_t>; + //----------------------------------------------------------------------------- class ResultBuilder @@ -164,8 +184,10 @@ public: void testFixedWidthBuckets(); void testThatNanIsConverted(); void testNanSorting(); + void testAttributeMapLookup(); int Main() override; private: + void testAggregationSimple(AggregationContext & ctx, const AggregationResult & aggr, const ResultNode & ir, const vespalib::string &name); void testAggregationSimpleSum(AggregationContext & ctx, const AggregationResult & aggr, const ResultNode & ir, const ResultNode & fr, const ResultNode & sr); class CheckAttributeReferences : public vespalib::ObjectOperation, public vespalib::ObjectPredicate { @@ -313,6 +335,18 @@ prepareAggr(const AggregationResult & aggr, ExpressionNode::UP expr, const Resul prepared->setResult(r); return prepared; } + +void Test::testAggregationSimple(AggregationContext & ctx, const AggregationResult & aggr, const ResultNode & ir, const vespalib::string &name) +{ + ExpressionNode::CP clone(aggr); + Grouping request; + request.setRoot(Group().addResult(prepareAggr(aggr, makeAttributeMapLookupNode(name)))); + + Group expect; + expect.addResult(prepareAggr(aggr, makeAttributeMapLookupNode(name), ir)); + EXPECT_TRUE(testAggregation(ctx, request, expect)); +} + void Test::testAggregationSimpleSum(AggregationContext & ctx, const AggregationResult & aggr, const ResultNode & ir, const ResultNode & fr, const ResultNode & sr) { ExpressionNode::CP clone(aggr); @@ -1884,6 +1918,28 @@ Test::testThatNanIsConverted() ASSERT_EQUAL(g.getRank(), g.getRank()); } +void +Test::testAttributeMapLookup() +{ + AggregationContext ctx; + ctx.result().add(0).add(1); + ctx.add(StringArrayAttrBuilder("smap.key").add({"k1", "k2"}).add({"k3", "k4"}).sp()); + ctx.add(IntArrayAttrBuilder("smap.value.weight").add({10, 20}).add({100, 200}).sp()); + ctx.add(StringAttrBuilder("key1").add("k1").add("k4").sp()); + ctx.add(StringAttrBuilder("key2").add("k2").add("k3").sp()); + ctx.add(StringAttrBuilder("key3").add("k3").add("k2").sp()); + testAggregationSimple(ctx, SumAggregationResult(), Int64ResultNode(10 + undefinedInteger), "smap{\"k1\"}.weight"); + testAggregationSimple(ctx, SumAggregationResult(), Int64ResultNode(20 + undefinedInteger), "smap{\"k2\"}.weight"); + testAggregationSimple(ctx, SumAggregationResult(), Int64ResultNode(2 * undefinedInteger), "smap{\"k5\"}.weight"); + testAggregationSimple(ctx, SumAggregationResult(), Int64ResultNode(210), "smap{attribute(key1)}.weight"); + testAggregationSimple(ctx, SumAggregationResult(), Int64ResultNode(120), "smap{attribute(key2)}.weight"); + testAggregationSimple(ctx, SumAggregationResult(), Int64ResultNode(2 * undefinedInteger), "smap{attribute(key3)}.weight"); + testAggregationSimple(ctx, MinAggregationResult(), Int64ResultNode(10), "smap{attribute(key1)}.weight"); + testAggregationSimple(ctx, MinAggregationResult(), Int64ResultNode(20), "smap{attribute(key2)}.weight"); + testAggregationSimple(ctx, MaxAggregationResult(), Int64ResultNode(200), "smap{attribute(key1)}.weight"); + testAggregationSimple(ctx, MaxAggregationResult(), Int64ResultNode(100), "smap{attribute(key2)}.weight"); +} + //----------------------------------------------------------------------------- struct RunDiff { ~RunDiff() { system("diff -u lhs.out rhs.out > diff.txt"); }}; @@ -1916,6 +1972,7 @@ Test::Main() testTopN(); testThatNanIsConverted(); testNanSorting(); + testAttributeMapLookup(); TEST_DONE(); } diff --git a/searchlib/src/vespa/searchlib/aggregation/modifiers.cpp b/searchlib/src/vespa/searchlib/aggregation/modifiers.cpp index a59bbe14404..5ffad122c7d 100644 --- a/searchlib/src/vespa/searchlib/aggregation/modifiers.cpp +++ b/searchlib/src/vespa/searchlib/aggregation/modifiers.cpp @@ -4,6 +4,7 @@ #include "grouping.h" #include <vespa/searchlib/expression/multiargfunctionnode.h> #include <vespa/searchlib/expression/attributenode.h> +#include <vespa/searchlib/expression/attribute_map_lookup_node.h> #include <vespa/searchlib/expression/documentfieldnode.h> using namespace search::expression; diff --git a/searchlib/src/vespa/searchlib/attribute/attributevector.h b/searchlib/src/vespa/searchlib/attribute/attributevector.h index 8cf4079ccfa..54a43bec09e 100644 --- a/searchlib/src/vespa/searchlib/attribute/attributevector.h +++ b/searchlib/src/vespa/searchlib/attribute/attributevector.h @@ -461,13 +461,6 @@ public: virtual uint32_t clearDoc(DocId doc) = 0; virtual largeint_t getDefaultValue() const = 0; - virtual void getEnumValue(const EnumHandle *v, uint32_t *e, uint32_t sz) const = 0; - - uint32_t getEnumValue(EnumHandle eh) const { - uint32_t e(0); - getEnumValue(&eh, &e, 1); - return e; - } // Implements IAttributeVector virtual uint32_t get(DocId doc, EnumHandle *v, uint32_t sz) const override = 0; diff --git a/searchlib/src/vespa/searchlib/attribute/attrvector.h b/searchlib/src/vespa/searchlib/attribute/attrvector.h index 2ba9ed083f0..c0530ee8368 100644 --- a/searchlib/src/vespa/searchlib/attribute/attrvector.h +++ b/searchlib/src/vespa/searchlib/attribute/attrvector.h @@ -34,11 +34,6 @@ private: NumericDirectAttribute & operator=(const NumericDirectAttribute &); bool onLoad() override; typename B::BaseType getFromEnum(EnumHandle e) const override { return _data[e]; } - void getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const override { - for (size_t i(0); i < sz; i++) { - e[i] = v[i]; - } - } protected: typedef typename B::BaseType BaseType; typedef typename B::DocId DocId; @@ -153,11 +148,6 @@ protected: ~StringDirectAttribute(); bool findEnum(const char * value, EnumHandle & e) const override; std::vector<EnumHandle> findFoldedEnums(const char *) const override; - void getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const override { - for (size_t i(0); i < sz; i++) { - e[i] = v[i]; - } - } void onCommit() override; void onUpdateStat() override { } bool addDoc(DocId & ) override; diff --git a/searchlib/src/vespa/searchlib/attribute/enumattribute.h b/searchlib/src/vespa/searchlib/attribute/enumattribute.h index 26c70d90cfa..993267f79a6 100644 --- a/searchlib/src/vespa/searchlib/attribute/enumattribute.h +++ b/searchlib/src/vespa/searchlib/attribute/enumattribute.h @@ -56,7 +56,6 @@ protected: const EnumStore & getEnumStore() const { return _enumStore; } const EnumStoreBase * getEnumStoreBase() const override { return &_enumStore; } - void getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const override { _enumStore.getEnumValue(v, e, sz); } EnumType getFromEnum(EnumHandle e) const override { return _enumStore.getValue(e); } void fillPostings(LoadedVector & loaded) override { (void) loaded; } diff --git a/searchlib/src/vespa/searchlib/attribute/enumstorebase.cpp b/searchlib/src/vespa/searchlib/attribute/enumstorebase.cpp index 61d862b6c4f..142883e54d6 100644 --- a/searchlib/src/vespa/searchlib/attribute/enumstorebase.cpp +++ b/searchlib/src/vespa/searchlib/attribute/enumstorebase.cpp @@ -136,14 +136,6 @@ EnumStoreBase::getAddressSpaceUsage() const } void -EnumStoreBase::getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const -{ - for(size_t i(0); i < sz; i++) { - e[i] = getEnum(Index(v[i])); - } -} - -void EnumStoreBase::transferHoldLists(generation_t generation) { _enumDict->onTransferHoldLists(generation); diff --git a/searchlib/src/vespa/searchlib/attribute/enumstorebase.h b/searchlib/src/vespa/searchlib/attribute/enumstorebase.h index 9bea2a568e1..9fb91169309 100644 --- a/searchlib/src/vespa/searchlib/attribute/enumstorebase.h +++ b/searchlib/src/vespa/searchlib/attribute/enumstorebase.h @@ -273,7 +273,6 @@ public: size_t getMaxEnumOffset() const { return _store.getBufferState(_store.getActiveBufferId(TYPE_ID)).size(); } - void getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const; uint32_t getRefCount(Index idx) const { return getEntryBase(idx).getRefCount(); } uint32_t getEnum(Index idx) const { return getEntryBase(idx).getEnum(); } void incRefCount(Index idx) { getEntryBase(idx).incRefCount(); } diff --git a/searchlib/src/vespa/searchlib/attribute/multienumattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multienumattribute.hpp index 158343ef7c0..0fd40ab027b 100644 --- a/searchlib/src/vespa/searchlib/attribute/multienumattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multienumattribute.hpp @@ -7,8 +7,6 @@ #include "multienumattributesaver.h" #include "load_utils.h" -#include <stdexcept> - namespace search { template <typename B, typename M> @@ -199,18 +197,10 @@ template <typename B, typename M> std::unique_ptr<AttributeSaver> MultiValueEnumAttribute<B, M>::onInitSave(vespalib::stringref fileName) { - { - this->logEnumStoreEvent("reenumerate", "drain"); - EnumModifier enumGuard(this->getEnumModifier()); - this->logEnumStoreEvent("reenumerate", "start"); - this->_enumStore.reEnumerate(); - } - this->logEnumStoreEvent("reenumerate", "complete"); - vespalib::GenerationHandler::Guard guard(this->getGenerationHandler(). - takeGuard()); + this->_enumStore.reEnumerate(); + vespalib::GenerationHandler::Guard guard(this->getGenerationHandler().takeGuard()); return std::make_unique<MultiValueEnumAttributeSaver<WeightedIndex>> - (std::move(guard), this->createAttributeHeader(fileName), this->_mvMapping, - this->_enumStore); + (std::move(guard), this->createAttributeHeader(fileName), this->_mvMapping, this->_enumStore); } } // namespace search diff --git a/searchlib/src/vespa/searchlib/attribute/multinumericattribute.h b/searchlib/src/vespa/searchlib/attribute/multinumericattribute.h index bea676ff0c3..4b951fd7ceb 100644 --- a/searchlib/src/vespa/searchlib/attribute/multinumericattribute.h +++ b/searchlib/src/vespa/searchlib/attribute/multinumericattribute.h @@ -43,12 +43,6 @@ private: T getFromEnum(EnumHandle e) const override; bool findEnum(T value, EnumHandle & e) const override; - void getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const override { - (void) v; - (void) e; - (void) sz; - } - protected: typedef typename B::generation_t generation_t; diff --git a/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.cpp b/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.cpp index 1dc95c42de8..e9743e3e86d 100644 --- a/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.cpp +++ b/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.cpp @@ -139,11 +139,6 @@ NotImplementedAttribute::getEnum(DocId) const { return 0; } -void -NotImplementedAttribute::getEnumValue(const EnumHandle *, uint32_t *, uint32_t) const { - notImplemented(); -} - bool NotImplementedAttribute::addDoc(DocId &) { notImplemented(); diff --git a/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.h b/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.h index cbd2ff162b2..4552a24ec2e 100644 --- a/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.h +++ b/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.h @@ -33,7 +33,6 @@ struct NotImplementedAttribute : AttributeVector { uint32_t clearDoc(DocId) override; int64_t getDefaultValue() const override; uint32_t getEnum(DocId) const override; - void getEnumValue(const EnumHandle *, uint32_t *, uint32_t) const override; bool addDoc(DocId &) override; void onAddDocs(DocId lidLimit) override; diff --git a/searchlib/src/vespa/searchlib/attribute/singleenumattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singleenumattribute.hpp index c7299cd71d9..cc9b0346690 100644 --- a/searchlib/src/vespa/searchlib/attribute/singleenumattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/singleenumattribute.hpp @@ -309,15 +309,8 @@ template <typename B> std::unique_ptr<AttributeSaver> SingleValueEnumAttribute<B>::onInitSave(vespalib::stringref fileName) { - { - this->logEnumStoreEvent("reenumerate", "drain"); - EnumModifier enumGuard(this->getEnumModifier()); - this->logEnumStoreEvent("reenumerate", "start"); - this->_enumStore.reEnumerate(); - } - this->logEnumStoreEvent("reenumerate", "complete"); - vespalib::GenerationHandler::Guard guard(this->getGenerationHandler(). - takeGuard()); + this->_enumStore.reEnumerate(); + vespalib::GenerationHandler::Guard guard(this->getGenerationHandler().takeGuard()); return std::make_unique<SingleValueEnumAttributeSaver> (std::move(guard), this->createAttributeHeader(fileName), diff --git a/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.h b/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.h index 06d1068b21a..81fda8b92fc 100644 --- a/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.h +++ b/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.h @@ -125,11 +125,6 @@ public: largeint_t getInt(DocId doc) const override { return static_cast<largeint_t>(getFast(doc)); } - void getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const override { - (void) v; - (void) e; - (void) sz; - } double getFloat(DocId doc) const override { return static_cast<double>(_data[doc]); } diff --git a/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.h b/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.h index f5f666bd89f..d5b65da08fa 100644 --- a/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.h +++ b/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.h @@ -141,11 +141,6 @@ public: largeint_t getInt(DocId doc) const override { return static_cast<largeint_t>(getFast(doc)); } - void getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const override { - (void) v; - (void) e; - (void) sz; - } double getFloat(DocId doc) const override { return static_cast<double>(getFast(doc)); } diff --git a/searchlib/src/vespa/searchlib/common/identifiable.h b/searchlib/src/vespa/searchlib/common/identifiable.h index 5a64e29ddf3..35e49b5cddf 100644 --- a/searchlib/src/vespa/searchlib/common/identifiable.h +++ b/searchlib/src/vespa/searchlib/common/identifiable.h @@ -148,6 +148,7 @@ #define CID_search_expression_AggregationRefNode SEARCHLIB_CID(142) #define CID_search_expression_NormalizeSubjectFunctionNode SEARCHLIB_CID(143) #define CID_search_expression_DebugWaitFunctionNode SEARCHLIB_CID(144) +#define CID_search_expression_AttributeMapLookupNode SEARCHLIB_CID(145) #define CID_search_QueryNode SEARCHLIB_CID(150) #define CID_search_Query SEARCHLIB_CID(151) diff --git a/searchlib/src/vespa/searchlib/engine/transportserver.cpp b/searchlib/src/vespa/searchlib/engine/transportserver.cpp index c5e59024c31..bc739a7bf48 100644 --- a/searchlib/src/vespa/searchlib/engine/transportserver.cpp +++ b/searchlib/src/vespa/searchlib/engine/transportserver.cpp @@ -7,6 +7,7 @@ #include <vespa/fnet/connection.h> #include <vespa/fnet/connector.h> #include <vespa/fnet/iexecutable.h> +#include <vespa/vespalib/net/crypto_engine.h> #include <vespa/log/log.h> LOG_SETUP(".engine.transportserver"); @@ -358,7 +359,7 @@ TransportServer::TransportServer(SearchServer &searchServer, : _searchServer(searchServer), _docsumServer(docsumServer), _monitorServer(monitorServer), - _transport(), + _transport(std::make_shared<vespalib::NullCryptoEngine>(), 1), // disable encryption _ready(false), _failed(false), _doListen(true), diff --git a/searchlib/src/vespa/searchlib/engine/transportserver.h b/searchlib/src/vespa/searchlib/engine/transportserver.h index 691f6fbe791..67d373d5940 100644 --- a/searchlib/src/vespa/searchlib/engine/transportserver.h +++ b/searchlib/src/vespa/searchlib/engine/transportserver.h @@ -301,13 +301,6 @@ public: void setTCPNoDelay(bool noDelay) { _transport.SetTCPNoDelay(noDelay); } /** - * Enable or disable the use of a Q for throughput between search thread and network thread. - * - * @param directWrite bypasses Q - **/ - void setDirectWrite(bool directWrite) { _transport.SetDirectWrite(directWrite); } - - /** * Set a limit on how long a connection may be idle before closing it. * * @param millisecs max idle time in milliseconds diff --git a/searchlib/src/vespa/searchlib/expression/CMakeLists.txt b/searchlib/src/vespa/searchlib/expression/CMakeLists.txt index 1b7a26bf621..652fa5a3b01 100644 --- a/searchlib/src/vespa/searchlib/expression/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/expression/CMakeLists.txt @@ -1,6 +1,7 @@ # Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. vespa_add_library(searchlib_expression OBJECT SOURCES + attribute_map_lookup_node.cpp attributenode.cpp attributeresult.cpp enumattributeresult.cpp diff --git a/searchlib/src/vespa/searchlib/expression/attribute_map_lookup_node.cpp b/searchlib/src/vespa/searchlib/expression/attribute_map_lookup_node.cpp new file mode 100644 index 00000000000..8a851b043aa --- /dev/null +++ b/searchlib/src/vespa/searchlib/expression/attribute_map_lookup_node.cpp @@ -0,0 +1,405 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "attribute_map_lookup_node.h" +#include <vespa/vespalib/stllike/asciistream.h> +#include <vespa/vespalib/util/exceptions.h> +#include <vespa/searchcommon/attribute/attributecontent.h> +#include <vespa/searchcommon/attribute/iattributecontext.h> +#include <vespa/searchcommon/common/undefinedvalues.h> + +using search::attribute::AttributeContent; +using search::attribute::IAttributeVector; +using search::attribute::BasicType; +using search::attribute::getUndefined; +using vespalib::Deserializer; +using vespalib::Serializer; +using EnumHandle = IAttributeVector::EnumHandle; + +namespace search::expression { + +IMPLEMENT_EXPRESSIONNODE(AttributeMapLookupNode, AttributeNode); + +class AttributeMapLookupNode::KeyHandler +{ +protected: + const IAttributeVector &_attribute; + + KeyHandler(const IAttributeVector &attribute) + : _attribute(attribute) + { + } +public: + static uint32_t noKeyIdx() { return std::numeric_limits<uint32_t>::max(); } + virtual ~KeyHandler() = default; + virtual uint32_t handle(DocId docId) = 0; +}; + +namespace { + +vespalib::string indirectKeyMarker("attribute("); + +class BadKeyHandler : public AttributeMapLookupNode::KeyHandler +{ +public: + BadKeyHandler(const IAttributeVector &attribute) + : KeyHandler(attribute) + { + } + uint32_t handle(DocId) override { return noKeyIdx(); } +}; + +template <typename KeyType> +KeyType convertKey(const IAttributeVector &, const vespalib::string &key) +{ + KeyType ret; + vespalib::asciistream is(key); + is >> ret; + return ret; +} + +template <> +vespalib::string convertKey<vespalib::string>(const IAttributeVector &, const vespalib::string &key) +{ + return key; +} + +template <> +EnumHandle convertKey<EnumHandle>(const IAttributeVector &attribute, const vespalib::string &key) +{ + EnumHandle ret; + if (!attribute.findEnum(key.c_str(), ret)) { + ret = EnumHandle(); + } + return ret; +} + +template <typename T, typename KeyType = T> +class KeyHandlerT : public AttributeMapLookupNode::KeyHandler +{ + AttributeContent<T> _keys; + KeyType _key; + +public: + KeyHandlerT(const IAttributeVector &attribute, const vespalib::string &key) + : KeyHandler(attribute), + _keys(), + _key(convertKey<KeyType>(attribute, key)) + { + } + ~KeyHandlerT() override; + uint32_t handle(DocId docId) override { + _keys.fill(_attribute, docId); + for (uint32_t i = 0; i < _keys.size(); ++i) { + if (_key == _keys[i]) { + return i; + } + } + return noKeyIdx(); + } +}; + +template <typename T, typename KeyType> +KeyHandlerT<T,KeyType>::~KeyHandlerT() +{ +} + +using IntegerKeyHandler = KeyHandlerT<IAttributeVector::largeint_t>; +using FloatKeyHandler = KeyHandlerT<double>; +using StringKeyHandler = KeyHandlerT<const char *, vespalib::string>; +using EnumKeyHandler = KeyHandlerT<EnumHandle>; + +template <typename T> +bool +matchingKey(T lhs, T rhs) +{ + return lhs == rhs; +} + +template <> +bool +matchingKey<const char *>(const char *lhs, const char *rhs) +{ + return (strcmp(lhs, rhs) == 0); +} + +template <typename T> +class IndirectKeyHandlerT : public AttributeMapLookupNode::KeyHandler +{ + const IAttributeVector &_keySourceAttribute; + AttributeContent<T> _keys; + +public: + IndirectKeyHandlerT(const IAttributeVector &attribute, const IAttributeVector &keySourceAttribute) + : KeyHandler(attribute), + _keySourceAttribute(keySourceAttribute), + _keys() + { + } + ~IndirectKeyHandlerT() override; + uint32_t handle(DocId docId) override { + T key = T(); + _keySourceAttribute.get(docId, &key, 1); + _keys.fill(_attribute, docId); + for (uint32_t i = 0; i < _keys.size(); ++i) { + if (matchingKey(key, _keys[i])) { + return i; + } + } + return noKeyIdx(); + } +}; + +template <typename T> +IndirectKeyHandlerT<T>::~IndirectKeyHandlerT() +{ +} + +using IndirectIntegerKeyHandler = IndirectKeyHandlerT<IAttributeVector::largeint_t>; +using IndirectFloatKeyHandler = IndirectKeyHandlerT<double>; +using IndirectStringKeyHandler = IndirectKeyHandlerT<const char *>; + +class ValueHandler : public AttributeNode::Handler +{ +protected: + std::unique_ptr<AttributeMapLookupNode::KeyHandler> _keyHandler; + const IAttributeVector &_attribute; + ValueHandler(std::unique_ptr<AttributeMapLookupNode::KeyHandler> keyHandler, const IAttributeVector &attribute) + : _keyHandler(std::move(keyHandler)), + _attribute(attribute) + { + } +}; + +template <typename T, typename ResultNodeType> +class ValueHandlerT : public ValueHandler +{ + AttributeContent<T> _values; + ResultNodeType &_result; + T _undefinedValue; +public: + ValueHandlerT(std::unique_ptr<AttributeMapLookupNode::KeyHandler> keyHandler, const IAttributeVector &attribute, ResultNodeType &result, T undefinedValue) + : ValueHandler(std::move(keyHandler), attribute), + _values(), + _result(result), + _undefinedValue(undefinedValue) + { + } + void handle(const AttributeResult & r) override { + uint32_t docId = r.getDocId(); + uint32_t keyIdx = _keyHandler->handle(docId); + if (keyIdx != AttributeMapLookupNode::KeyHandler::noKeyIdx()) { + _values.fill(_attribute, docId); + if (keyIdx < _values.size()) { + _result = _values[keyIdx]; + return; + } + } + _result = _undefinedValue; + } +}; + +template <typename ResultNodeType> +using IntegerValueHandler = ValueHandlerT<IAttributeVector::largeint_t, ResultNodeType>; +using FloatValueHandler = ValueHandlerT<double, FloatResultNode>; +using StringValueHandler = ValueHandlerT<const char *, StringResultNode>; +using EnumValueHandler = ValueHandlerT<EnumHandle, EnumResultNode>; + +const IAttributeVector *findAttribute(const search::attribute::IAttributeContext &attrCtx, bool useEnumOptimization, const vespalib::string &name) +{ + const IAttributeVector *attribute = useEnumOptimization ? attrCtx.getAttributeStableEnum(name) : attrCtx.getAttribute(name); + if (attribute == nullptr) { + throw std::runtime_error(vespalib::make_string("Failed locating attribute vector '%s'", name.c_str())); + } + return attribute; +} + +IAttributeVector::largeint_t getUndefinedValue(BasicType::Type basicType) +{ + switch (basicType) { + case BasicType::INT8: + return getUndefined<int8_t>(); + case BasicType::INT16: + return getUndefined<int16_t>(); + case BasicType::INT32: + return getUndefined<int32_t>(); + case BasicType::INT64: + return getUndefined<int64_t>(); + break; + default: + return 0; + } +} + +} + +AttributeMapLookupNode::AttributeMapLookupNode() + : AttributeNode(), + _keyAttributeName(), + _valueAttributeName(), + _key(), + _keySourceAttributeName(), + _keyAttribute(nullptr), + _keySourceAttribute(nullptr) +{ +} + +AttributeMapLookupNode::AttributeMapLookupNode(const AttributeMapLookupNode &) = default; + +AttributeMapLookupNode::AttributeMapLookupNode(vespalib::stringref name, vespalib::stringref keyAttributeName, vespalib::stringref valueAttributeName, vespalib::stringref key, vespalib::stringref keySourceAttributeName) + : AttributeNode(name), + _keyAttributeName(keyAttributeName), + _valueAttributeName(valueAttributeName), + _key(key), + _keySourceAttributeName(keySourceAttributeName), + _keyAttribute(nullptr), + _keySourceAttribute(nullptr) +{ +} + +AttributeMapLookupNode::~AttributeMapLookupNode() = default; + +AttributeMapLookupNode & +AttributeMapLookupNode::operator=(const AttributeMapLookupNode &rhs) = default; + +template <typename ResultNodeType> +void +AttributeMapLookupNode::prepareIntValues(std::unique_ptr<KeyHandler> keyHandler, const IAttributeVector &attribute, IAttributeVector::largeint_t undefinedValue) +{ + auto resultNode = std::make_unique<ResultNodeType>(); + _handler = std::make_unique<IntegerValueHandler<ResultNodeType>>(std::move(keyHandler), attribute, *resultNode, undefinedValue); + setResultType(std::move(resultNode)); +} + +std::unique_ptr<AttributeMapLookupNode::KeyHandler> +AttributeMapLookupNode::makeKeyHandlerHelper() +{ + const IAttributeVector &attribute = *_keyAttribute; + if (_keySourceAttribute != nullptr) { + const IAttributeVector &keySourceAttribute = *_keySourceAttribute; + if (attribute.isIntegerType() && keySourceAttribute.isIntegerType()) { + return std::make_unique<IndirectIntegerKeyHandler>(attribute, keySourceAttribute); + } else if (attribute.isFloatingPointType() && keySourceAttribute.isFloatingPointType()) { + return std::make_unique<IndirectFloatKeyHandler>(attribute, keySourceAttribute); + } else if (attribute.isStringType() && keySourceAttribute.isStringType()) { + return std::make_unique<IndirectStringKeyHandler>(attribute, keySourceAttribute); + } else { + return std::make_unique<BadKeyHandler>(attribute); + } + } + if (attribute.hasEnum() && _useEnumOptimization) { + return std::make_unique<EnumKeyHandler>(attribute, _key); + } else if (attribute.isIntegerType()) { + return std::make_unique<IntegerKeyHandler>(attribute, _key); + } else if (attribute.isFloatingPointType()) { + return std::make_unique<FloatKeyHandler>(attribute, _key); + } else if (attribute.isStringType()) { + return std::make_unique<StringKeyHandler>(attribute, _key); + } else { + return std::make_unique<BadKeyHandler>(attribute); + } +} + +std::unique_ptr<AttributeMapLookupNode::KeyHandler> +AttributeMapLookupNode::makeKeyHandler() +{ + try { + return makeKeyHandlerHelper(); + } catch (const vespalib::IllegalArgumentException &) { + return std::make_unique<BadKeyHandler>(*_keyAttribute); + } +} + +void +AttributeMapLookupNode::onPrepare(bool preserveAccurateTypes) +{ + auto keyHandler = makeKeyHandler(); + const IAttributeVector * attribute = _scratchResult->getAttribute(); + if (attribute != nullptr) { + BasicType::Type basicType = attribute->getBasicType(); + if (attribute->isIntegerType()) { + IAttributeVector::largeint_t undefinedValue = getUndefinedValue(basicType); + if (preserveAccurateTypes) { + switch (basicType) { + case BasicType::INT8: + prepareIntValues<Int8ResultNode>(std::move(keyHandler), *attribute, undefinedValue); + break; + case BasicType::INT16: + prepareIntValues<Int16ResultNode>(std::move(keyHandler), *attribute, undefinedValue); + break; + case BasicType::INT32: + prepareIntValues<Int32ResultNode>(std::move(keyHandler), *attribute, undefinedValue); + break; + case BasicType::INT64: + prepareIntValues<Int64ResultNode>(std::move(keyHandler), *attribute, undefinedValue); + break; + default: + throw std::runtime_error("This is no valid integer attribute " + attribute->getName()); + break; + } + } else { + prepareIntValues<Int64ResultNode>(std::move(keyHandler), *attribute, undefinedValue); + } + } else if (attribute->isFloatingPointType()) { + auto resultNode = std::make_unique<FloatResultNode>(); + _handler = std::make_unique<FloatValueHandler>(std::move(keyHandler), *attribute, *resultNode, getUndefined<double>()); + setResultType(std::move(resultNode)); + } else if (attribute->isStringType()) { + if (_useEnumOptimization) { + auto resultNode = std::make_unique<EnumResultNode>(); + _handler = std::make_unique<EnumValueHandler>(std::move(keyHandler), *attribute, *resultNode, EnumHandle()); + setResultType(std::move(resultNode)); + } else { + auto resultNode = std::make_unique<StringResultNode>(); + _handler = std::make_unique<StringValueHandler>(std::move(keyHandler), *attribute, *resultNode, ""); + setResultType(std::move(resultNode)); + } + } else { + throw std::runtime_error(vespalib::make_string("Can not deduce correct resultclass for attribute vector '%s'", + attribute->getName().c_str())); + } + } +} + +void +AttributeMapLookupNode::cleanup() +{ + _keyAttribute = nullptr; + _keySourceAttribute = nullptr; + AttributeNode::cleanup(); +} + +void +AttributeMapLookupNode::wireAttributes(const search::attribute::IAttributeContext &attrCtx) +{ + auto valueAttribute = findAttribute(attrCtx, _useEnumOptimization, _valueAttributeName); + _hasMultiValue = false; + _scratchResult = std::make_unique<AttributeResult>(valueAttribute, 0); + _keyAttribute = findAttribute(attrCtx, _useEnumOptimization, _keyAttributeName); + if (!_keySourceAttributeName.empty()) { + _keySourceAttribute = findAttribute(attrCtx, false, _keySourceAttributeName); + } +} + +Serializer & AttributeMapLookupNode::onSerialize(Serializer & os) const +{ + AttributeNode::onSerialize(os); + return os << _keyAttributeName << _valueAttributeName << _key << _keySourceAttributeName; +} + +Deserializer & AttributeMapLookupNode::onDeserialize(Deserializer & is) +{ + AttributeNode::onDeserialize(is); + return is >> _keyAttributeName >> _valueAttributeName >> _key >> _keySourceAttributeName; +} + +void +AttributeMapLookupNode::visitMembers(vespalib::ObjectVisitor &visitor) const +{ + AttributeNode::visitMembers(visitor); + visit(visitor, "keyAttributeName", _keyAttributeName); + visit(visitor, "keySourceAttributeName", _keySourceAttributeName); + visit(visitor, "valueAttributeName", _valueAttributeName); + visit(visitor, "key", _key); +} + +} diff --git a/searchlib/src/vespa/searchlib/expression/attribute_map_lookup_node.h b/searchlib/src/vespa/searchlib/expression/attribute_map_lookup_node.h new file mode 100644 index 00000000000..2f9c6328969 --- /dev/null +++ b/searchlib/src/vespa/searchlib/expression/attribute_map_lookup_node.h @@ -0,0 +1,45 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "attributenode.h" + +namespace search::expression { + +/** + * Extract map value from attribute for the map key specified in the + * grouping expression. + */ +class AttributeMapLookupNode : public AttributeNode +{ +public: + using IAttributeVector = search::attribute::IAttributeVector; + class KeyHandler; +private: + vespalib::string _keyAttributeName; + vespalib::string _valueAttributeName; + vespalib::string _key; + vespalib::string _keySourceAttributeName; + const IAttributeVector *_keyAttribute; + const IAttributeVector *_keySourceAttribute; + + template <typename ResultNodeType> + void prepareIntValues(std::unique_ptr<KeyHandler> keyHandler, const IAttributeVector &attribute, IAttributeVector::largeint_t undefinedValue); + std::unique_ptr<KeyHandler> makeKeyHandlerHelper(); + std::unique_ptr<KeyHandler> makeKeyHandler(); + void cleanup() override; + void wireAttributes(const search::attribute::IAttributeContext & attrCtx) override; + void onPrepare(bool preserveAccurateTypes) override; +public: + DECLARE_NBO_SERIALIZE; + DECLARE_EXPRESSIONNODE(AttributeMapLookupNode); + AttributeMapLookupNode(); + AttributeMapLookupNode(vespalib::stringref name, vespalib::stringref keyAttributeName, vespalib::stringref valueAttributeName, vespalib::stringref key, vespalib::stringref keySourceAttributeName); + AttributeMapLookupNode(const AttributeMapLookupNode &); + AttributeMapLookupNode(AttributeMapLookupNode &&) = delete; + ~AttributeMapLookupNode() override; + AttributeMapLookupNode &operator=(const AttributeMapLookupNode &rhs); + AttributeMapLookupNode &operator=(AttributeMapLookupNode &&rhs) = delete; + void visitMembers(vespalib::ObjectVisitor &visitor) const override; +}; + +} diff --git a/searchlib/src/vespa/searchlib/expression/attributenode.h b/searchlib/src/vespa/searchlib/expression/attributenode.h index 3cbccd32e60..472267f4b5c 100644 --- a/searchlib/src/vespa/searchlib/expression/attributenode.h +++ b/searchlib/src/vespa/searchlib/expression/attributenode.h @@ -55,7 +55,7 @@ public: void useEnumOptimization(bool use=true) { _useEnumOptimization = use; } bool hasMultiValue() const { return _hasMultiValue; } -protected: +public: class Handler { public: @@ -68,7 +68,7 @@ private: class StringHandler; class EnumHandler; protected: - void cleanup(); + virtual void cleanup(); void wireAttributes(const search::attribute::IAttributeContext & attrCtx) override; void onPrepare(bool preserveAccurateTypes) override; bool onExecute() const override; diff --git a/searchlib/src/vespa/searchlib/fef/indexproperties.cpp b/searchlib/src/vespa/searchlib/fef/indexproperties.cpp index b05d5fb4e54..5cd6c479d24 100644 --- a/searchlib/src/vespa/searchlib/fef/indexproperties.cpp +++ b/searchlib/src/vespa/searchlib/fef/indexproperties.cpp @@ -399,7 +399,13 @@ const uint32_t HeapSize::DEFAULT_VALUE(100); uint32_t HeapSize::lookup(const Properties &props) { - return lookupUint32(props, NAME, DEFAULT_VALUE); + return lookup(props, DEFAULT_VALUE); +} + +uint32_t +HeapSize::lookup(const Properties &props, uint32_t defaultValue) +{ + return lookupUint32(props, NAME, defaultValue); } const vespalib::string ArraySize::NAME("vespa.hitcollector.arraysize"); diff --git a/searchlib/src/vespa/searchlib/fef/indexproperties.h b/searchlib/src/vespa/searchlib/fef/indexproperties.h index 68bed502121..8b78e347a90 100644 --- a/searchlib/src/vespa/searchlib/fef/indexproperties.h +++ b/searchlib/src/vespa/searchlib/fef/indexproperties.h @@ -320,6 +320,7 @@ namespace hitcollector { static const vespalib::string NAME; static const uint32_t DEFAULT_VALUE; static uint32_t lookup(const Properties &props); + static uint32_t lookup(const Properties &props, uint32_t defaultValue); }; /** diff --git a/searchlib/src/vespa/searchlib/test/CMakeLists.txt b/searchlib/src/vespa/searchlib/test/CMakeLists.txt index b4fb895c0e2..1231a99920e 100644 --- a/searchlib/src/vespa/searchlib/test/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/test/CMakeLists.txt @@ -3,6 +3,7 @@ vespa_add_library(searchlib_test SOURCES document_weight_attribute_helper.cpp initrange.cpp + make_attribute_map_lookup_node.cpp mock_attribute_context.cpp mock_attribute_manager.cpp searchiteratorverifier.cpp diff --git a/searchlib/src/vespa/searchlib/test/make_attribute_map_lookup_node.cpp b/searchlib/src/vespa/searchlib/test/make_attribute_map_lookup_node.cpp new file mode 100644 index 00000000000..6f717d99d2c --- /dev/null +++ b/searchlib/src/vespa/searchlib/test/make_attribute_map_lookup_node.cpp @@ -0,0 +1,38 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "make_attribute_map_lookup_node.h" +#include <vespa/searchlib/expression/attribute_map_lookup_node.h> +#include <vespa/vespalib/stllike/asciistream.h> + +namespace search::expression::test { + +namespace { + +vespalib::string indirectKeyMarker("attribute("); + +} + +std::unique_ptr<AttributeNode> +makeAttributeMapLookupNode(const vespalib::string attributeName) +{ + vespalib::asciistream keyName; + vespalib::asciistream valueName; + auto leftBracePos = attributeName.find('{'); + auto baseName = attributeName.substr(0, leftBracePos); + auto rightBracePos = attributeName.rfind('}'); + keyName << baseName << ".key"; + valueName << baseName << ".value" << attributeName.substr(rightBracePos + 1); + if (rightBracePos != vespalib::string::npos && rightBracePos > leftBracePos) { + if (attributeName[leftBracePos + 1] == '"' && attributeName[rightBracePos - 1] == '"') { + vespalib::string key = attributeName.substr(leftBracePos + 2, rightBracePos - leftBracePos - 3); + return std::make_unique<AttributeMapLookupNode>(attributeName, keyName.str(), valueName.str(), key, ""); + } else if (attributeName.substr(leftBracePos + 1, indirectKeyMarker.size()) == indirectKeyMarker && attributeName[rightBracePos - 1] == ')') { + auto startPos = leftBracePos + 1 + indirectKeyMarker.size(); + vespalib::string keySourceAttributeName = attributeName.substr(startPos, rightBracePos - 1 - startPos); + return std::make_unique<AttributeMapLookupNode>(attributeName, keyName.str(), valueName.str(), "", keySourceAttributeName); + } + } + return std::unique_ptr<AttributeNode>(); +} + +} diff --git a/searchlib/src/vespa/searchlib/test/make_attribute_map_lookup_node.h b/searchlib/src/vespa/searchlib/test/make_attribute_map_lookup_node.h new file mode 100644 index 00000000000..3434c8f2ae3 --- /dev/null +++ b/searchlib/src/vespa/searchlib/test/make_attribute_map_lookup_node.h @@ -0,0 +1,14 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <vespa/vespalib/stllike/string.h> +#include <memory> + +namespace search::expression { class AttributeNode; } + +namespace search::expression::test { + +std::unique_ptr<AttributeNode> +makeAttributeMapLookupNode(const vespalib::string attributeName); + +} diff --git a/searchlib/src/vespa/searchlib/transactionlog/domain.cpp b/searchlib/src/vespa/searchlib/transactionlog/domain.cpp index 1caff132779..fc9518ccf1b 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/domain.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/domain.cpp @@ -60,13 +60,13 @@ Domain::Domain(const string &domainName, const string & baseDir, Executor & comm } _sessionExecutor.sync(); if (_parts.empty() || _parts.crbegin()->second->isClosed()) { - _parts[lastPart].reset(new DomainPart(_name, dir(), lastPart, _defaultCrcType, _fileHeaderContext, false)); + _parts[lastPart] = std::make_shared<DomainPart>(_name, dir(), lastPart, _defaultCrcType, _fileHeaderContext, false); vespalib::File::sync(dir()); } } void Domain::addPart(int64_t partId, bool isLastPart) { - DomainPart::SP dp(new DomainPart(_name, dir(), partId, _defaultCrcType, _fileHeaderContext, isLastPart)); + auto dp = std::make_shared<DomainPart>(_name, dir(), partId, _defaultCrcType, _fileHeaderContext, isLastPart); if (dp->size() == 0) { // Only last domain part is allowed to be truncated down to // empty size. @@ -199,7 +199,7 @@ Domain::triggerSyncNow() if (!_pendingSync) { _pendingSync = true; DomainPart::SP dp(_parts.rbegin()->second); - _commitExecutor.execute(Sync::UP(new Sync(_syncMonitor, dp, _pendingSync))); + _commitExecutor.execute(std::make_unique<Sync>(_syncMonitor, dp, _pendingSync)); } } @@ -290,7 +290,7 @@ void Domain::commit(const Packet & packet) triggerSyncNow(); waitPendingSync(_syncMonitor, _pendingSync); dp->close(); - dp.reset(new DomainPart(_name, dir(), entry.serial(), _defaultCrcType, _fileHeaderContext, false)); + dp = std::make_shared<DomainPart>(_name, dir(), entry.serial(), _defaultCrcType, _fileHeaderContext, false); { LockGuard guard(_lock); _parts[entry.serial()] = dp; @@ -322,15 +322,16 @@ bool Domain::erase(SerialNum to) } int Domain::visit(const Domain::SP & domain, SerialNum from, SerialNum to, - FRT_Supervisor & supervisor, FNET_Connection *conn) + std::unique_ptr<Session::Destination> dest) { assert(this == domain.get()); cleanSessions(); SerialNumRange range(from, to); - Session * session = new Session(_sessionId++, range, domain, supervisor, conn); + auto session = std::make_shared<Session>(_sessionId++, range, domain, std::move(dest)); + int id = session->id(); LockGuard guard(_sessionLock); - _sessions[session->id()] = Session::SP(session); - return session->id(); + _sessions[id] = std::move(session); + return id; } int Domain::startSession(int sessionId) diff --git a/searchlib/src/vespa/searchlib/transactionlog/domain.h b/searchlib/src/vespa/searchlib/transactionlog/domain.h index c1ff9157a6f..c0ee484926c 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/domain.h +++ b/searchlib/src/vespa/searchlib/transactionlog/domain.h @@ -51,7 +51,7 @@ public: bool erase(SerialNum to); void commit(const Packet & packet); - int visit(const Domain::SP & self, SerialNum from, SerialNum to, FRT_Supervisor & supervisor, FNET_Connection *conn); + int visit(const Domain::SP & self, SerialNum from, SerialNum to, std::unique_ptr<Session::Destination> dest); SerialNum begin() const; SerialNum end() const; diff --git a/searchlib/src/vespa/searchlib/transactionlog/session.cpp b/searchlib/src/vespa/searchlib/transactionlog/session.cpp index cbcbc68fdff..e703c32484f 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/session.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/session.cpp @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "session.h" #include "domain.h" -#include <vespa/fnet/frt/supervisor.h> #include <vespa/fastlib/io/bufferedfile.h> #include <vespa/log/log.h> @@ -11,14 +10,10 @@ using vespalib::LockGuard; namespace search::transactionlog { -namespace { - const double NEVER(-1.0); -} - vespalib::Executor::Task::UP Session::createTask(const Session::SP & session) { - return Task::UP(new VisitTask(session)); + return std::make_unique<VisitTask>(session); } Session::VisitTask::VisitTask(const Session::SP & session) @@ -86,7 +81,7 @@ Session::visitOnly() } bool Session::finished() const { - return _finished || (_connection->GetState() != FNET_Connection::FNET_CONNECTED); + return _finished || ! _destination->connected(); } void @@ -99,95 +94,31 @@ Session::finalize() _finished = true; } -int32_t -Session::rpc(FRT_RPCRequest * req) -{ - int32_t retval(-7); - LOG(debug, "rpc %s starting.", req->GetMethodName()); - FRT_Supervisor::InvokeSync(_supervisor.GetTransport(), _connection, req, NEVER); - if (req->GetErrorCode() == FRTE_NO_ERROR) { - retval = (req->GetReturn()->GetValue(0)._intval32); - LOG(debug, "rpc %s = %d\n", req->GetMethodName(), retval); - } else if (req->GetErrorCode() == FRTE_RPC_TIMEOUT) { - LOG(warning, "rpc %s timed out. Will allow to continue: error(%d): %s\n", req->GetMethodName(), req->GetErrorCode(), req->GetErrorMessage()); - retval = -req->GetErrorCode(); - } else { - if (req->GetErrorCode() != FRTE_RPC_CONNECTION) { - LOG(warning, "rpc %s: error(%d): %s\n", req->GetMethodName(), req->GetErrorCode(), req->GetErrorMessage()); - } - retval = -req->GetErrorCode(); - _ok = false; - } - return retval; -} - -void -Session::RequestDone(FRT_RPCRequest * req) -{ - _ok = (req->GetErrorCode() == FRTE_NO_ERROR); - if (req->GetErrorCode() != FRTE_NO_ERROR) { - LOG(warning, "rpcAsync failed %s: error(%d): %s\n", req->GetMethodName(), req->GetErrorCode(), req->GetErrorMessage()); - } else { - int32_t retval = req->GetReturn()->GetValue(0)._intval32; - if (retval != RPC::OK) { - LOG(error, "Return value != OK in RequestDone for method '%s'", req->GetMethodName()); - } - } - req->SubRef(); -} - Session::Session(int sId, const SerialNumRange & r, const Domain::SP & d, - FRT_Supervisor & supervisor, FNET_Connection *conn) : - _supervisor(supervisor), - _connection(conn), + std::unique_ptr<Destination> destination) : + _destination(std::move(destination)), _domain(d), _range(r), _id(sId), - _ok(true), _visitRunning(false), _inSync(false), _finished(false), _startTime() { - _connection->AddRef(); } -Session::~Session() -{ - _connection->SubRef(); -} +Session::~Session() = default; bool Session::send(const Packet & packet) { - FRT_RPCRequest *req = _supervisor.AllocRPCRequest(); - req->SetMethodName("visitCallback"); - req->GetParams()->AddString(_domain->name().c_str()); - req->GetParams()->AddInt32(id()); - req->GetParams()->AddData(packet.getHandle().c_str(), packet.getHandle().size()); - return send(req); -} - -bool -Session::send(FRT_RPCRequest * req) -{ - int32_t retval = rpc(req); - if ( ! ((retval == RPC::OK) || (retval == FRTE_RPC_CONNECTION)) ) { - LOG(error, "Return value != OK(%d) in send for method 'visitCallback'.", retval); - } - req->SubRef(); - - return (retval == RPC::OK); + return _destination->send(_id, _domain->name(), packet); } bool Session::sendDone() { - FRT_RPCRequest *req = _supervisor.AllocRPCRequest(); - req->SetMethodName("eofCallback"); - req->GetParams()->AddString(_domain->name().c_str()); - req->GetParams()->AddInt32(id()); - bool retval(send(req)); + bool retval = _destination->sendDone(_id, _domain->name()); _inSync = true; return retval; } diff --git a/searchlib/src/vespa/searchlib/transactionlog/session.h b/searchlib/src/vespa/searchlib/transactionlog/session.h index 29038ec5290..bf35d83c000 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/session.h +++ b/searchlib/src/vespa/searchlib/transactionlog/session.h @@ -4,9 +4,9 @@ #include "common.h" #include <vespa/vespalib/util/executor.h> #include <vespa/vespalib/util/sync.h> -#include <vespa/fnet/frt/invoker.h> #include <chrono> #include <deque> +#include <atomic> class FastOS_FileInterface; @@ -16,22 +16,29 @@ class Domain; class DomainPart; using DomainSP = std::shared_ptr<Domain>; -class Session : public FRT_IRequestWait +class Session { private: using Task = vespalib::Executor::Task; using time_point = std::chrono::time_point<std::chrono::steady_clock>; public: + class Destination { + public: + virtual ~Destination() {} + virtual bool send(int32_t id, const vespalib::string & domain, const Packet & packet) = 0; + virtual bool sendDone(int32_t id, const vespalib::string & domain) = 0; + virtual bool connected() const = 0; + virtual bool ok() const = 0; + }; typedef std::shared_ptr<Session> SP; Session(const Session &) = delete; Session & operator = (const Session &) = delete; - Session(int sId, const SerialNumRange & r, const DomainSP & d, FRT_Supervisor & supervisor, FNET_Connection *conn); + Session(int sId, const SerialNumRange & r, const DomainSP & d, std::unique_ptr<Destination> destination); ~Session(); const SerialNumRange & range() const { return _range; } int id() const { return _id; } bool inSync() const { return _inSync; } - bool ok() const { return _ok; } bool finished() const; static Task::UP createTask(const Session::SP & session); void setStartTime(time_point startTime) { _startTime = startTime; } @@ -47,8 +54,7 @@ private: Session::SP _session; }; - bool send(FRT_RPCRequest * req); - void RequestDone(FRT_RPCRequest *req) override; + bool ok() const { return _destination->ok(); } bool send(const Packet & packet); bool sendDone(); void visit(); @@ -56,17 +62,14 @@ private: void startVisit(); void finalize(); bool visit(FastOS_FileInterface & file, DomainPart & dp) __attribute__((noinline)); - int32_t rpc(FRT_RPCRequest * req); - FRT_Supervisor & _supervisor; - FNET_Connection * _connection; - DomainSP _domain; - SerialNumRange _range; - int _id; - bool _ok; - std::atomic<bool> _visitRunning; - std::atomic<bool> _inSync; - std::atomic<bool> _finished; - time_point _startTime; + std::unique_ptr<Destination> _destination; + DomainSP _domain; + SerialNumRange _range; + int _id; + std::atomic<bool> _visitRunning; + std::atomic<bool> _inSync; + std::atomic<bool> _finished; + time_point _startTime; }; } diff --git a/searchlib/src/vespa/searchlib/transactionlog/translogclient.cpp b/searchlib/src/vespa/searchlib/transactionlog/translogclient.cpp index aa2b558ea0c..767c8b45e10 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/translogclient.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/translogclient.cpp @@ -11,15 +11,40 @@ LOG_SETUP(".translogclient"); using namespace std::chrono_literals; +VESPA_THREAD_STACK_TAG(translogclient_rpc_callback) + namespace search::transactionlog { namespace { const double NEVER(-1.0); } +namespace { + +struct RpcTask : public vespalib::Executor::Task { + FRT_RPCRequest *req; + std::function<void(FRT_RPCRequest *req)> fun; + RpcTask(FRT_RPCRequest *req_in, std::function<void(FRT_RPCRequest *req)> &&fun_in) + : req(req_in), fun(std::move(fun_in)) {} + void run() override { + fun(req); + req->Return(); + req = nullptr; + } + ~RpcTask() { + if (req != nullptr) { + req->SetError(FRTE_RPC_METHOD_FAILED, "client has been shut down"); + req->Return(); + } + } +}; + +} + using vespalib::LockGuard; TransLogClient::TransLogClient(const vespalib::string & rpcTarget) : + _executor(1, 128 * 1024, translogclient_rpc_callback), _rpcTarget(rpcTarget), _sessions(), _supervisor(std::make_unique<FRT_Supervisor>()), @@ -33,6 +58,7 @@ TransLogClient::TransLogClient(const vespalib::string & rpcTarget) : TransLogClient::~TransLogClient() { disconnect(); + _executor.shutdown().sync(); _supervisor->ShutDown(true); } @@ -139,7 +165,7 @@ void TransLogClient::exportRPC(FRT_Supervisor & supervisor) FRT_ReflectionBuilder rb( & supervisor); //-- Visit Callbacks ----------------------------------------------------------- - rb.DefineMethod("visitCallback", "six", "i", false, FRT_METHOD(TransLogClient::visitCallbackRPC), this); + rb.DefineMethod("visitCallback", "six", "i", FRT_METHOD(TransLogClient::visitCallbackRPC_hook), this); rb.MethodDesc("Will return data asked from a subscriber/visitor."); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("session", "Session handle."); @@ -147,14 +173,15 @@ void TransLogClient::exportRPC(FRT_Supervisor & supervisor) rb.ReturnDesc("result", "A resultcode(int) of the operation. Non zero number indicates error."); //-- Visit Callbacks ----------------------------------------------------------- - rb.DefineMethod("eofCallback", "si", "i", false, FRT_METHOD(TransLogClient::eofCallbackRPC), this); + rb.DefineMethod("eofCallback", "si", "i", FRT_METHOD(TransLogClient::eofCallbackRPC_hook), this); rb.MethodDesc("Will tell you that you are done with the visitor."); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("session", "Session handle."); rb.ReturnDesc("result", "A resultcode(int) of the operation. Non zero number indicates error."); } -void TransLogClient::visitCallbackRPC(FRT_RPCRequest *req) + +void TransLogClient::do_visitCallbackRPC(FRT_RPCRequest *req) { uint32_t retval(uint32_t(-1)); FRT_Values & params = *req->GetParams(); @@ -171,7 +198,7 @@ void TransLogClient::visitCallbackRPC(FRT_RPCRequest *req) LOG(debug, "visitCallback(%s, %d)=%d done", domainName, sessionId, retval); } -void TransLogClient::eofCallbackRPC(FRT_RPCRequest *req) +void TransLogClient::do_eofCallbackRPC(FRT_RPCRequest *req) { uint32_t retval(uint32_t(-1)); FRT_Values & params = *req->GetParams(); @@ -188,6 +215,16 @@ void TransLogClient::eofCallbackRPC(FRT_RPCRequest *req) LOG(debug, "eofCallback(%s, %d)=%d done", domainName, sessionId, retval); } +void TransLogClient::visitCallbackRPC_hook(FRT_RPCRequest *req) +{ + _executor.execute(std::make_unique<RpcTask>(req->Detach(), [this](FRT_RPCRequest *x){ do_visitCallbackRPC(x); })); +} + +void TransLogClient::eofCallbackRPC_hook(FRT_RPCRequest *req) +{ + _executor.execute(std::make_unique<RpcTask>(req->Detach(), [this](FRT_RPCRequest *x){ do_eofCallbackRPC(x); })); +} + TransLogClient::Session::Session(const vespalib::string & domain, TransLogClient & tlc) : _tlc(tlc), diff --git a/searchlib/src/vespa/searchlib/transactionlog/translogclient.h b/searchlib/src/vespa/searchlib/transactionlog/translogclient.h index 87901890673..267d6e3b0ed 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/translogclient.h +++ b/searchlib/src/vespa/searchlib/transactionlog/translogclient.h @@ -5,6 +5,7 @@ #include <vespa/document/util/bytebuffer.h> #include <vespa/vespalib/util/sync.h> #include <vespa/vespalib/util/buffer.h> +#include <vespa/vespalib/util/threadstackexecutor.h> #include <vespa/fnet/frt/invokable.h> #include <map> #include <vector> @@ -96,8 +97,10 @@ public: const vespalib::string &getRPCTarget() const { return _rpcTarget; } private: void exportRPC(FRT_Supervisor & supervisor); - void visitCallbackRPC(FRT_RPCRequest *req); - void eofCallbackRPC(FRT_RPCRequest *req); + void do_visitCallbackRPC(FRT_RPCRequest *req); + void do_eofCallbackRPC(FRT_RPCRequest *req); + void visitCallbackRPC_hook(FRT_RPCRequest *req); + void eofCallbackRPC_hook(FRT_RPCRequest *req); int32_t rpc(FRT_RPCRequest * req); Session * findSession(const vespalib::string & domain, int sessionId); @@ -114,6 +117,7 @@ private: typedef std::map< SessionKey, Session * > SessionMap; + vespalib::ThreadStackExecutor _executor; vespalib::string _rpcTarget; SessionMap _sessions; //Brute force lock for subscriptions. For multithread safety. diff --git a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp index 65bb682a389..4b3e7bddb07 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp @@ -4,6 +4,8 @@ #include <vespa/vespalib/io/fileutil.h> #include <vespa/vespalib/util/exceptions.h> #include <vespa/fnet/frt/supervisor.h> +#include <vespa/fnet/frt/rpcrequest.h> +#include <vespa/fnet/task.h> #include <fstream> #include <vespa/log/log.h> @@ -26,21 +28,16 @@ class SyncHandler : public FNET_Task SerialNum _syncTo; public: - SyncHandler(FRT_Supervisor *supervisor, - FRT_RPCRequest *req,const Domain::SP &domain, - const TransLogServer::Session::SP &session, - SerialNum syncTo); + SyncHandler(FRT_Supervisor *supervisor, FRT_RPCRequest *req,const Domain::SP &domain, + const TransLogServer::Session::SP &session, SerialNum syncTo); ~SyncHandler(); void PerformTask() override; }; -SyncHandler::SyncHandler(FRT_Supervisor *supervisor, - FRT_RPCRequest *req, - const Domain::SP &domain, - const TransLogServer::Session::SP &session, - SerialNum syncTo) +SyncHandler::SyncHandler(FRT_Supervisor *supervisor, FRT_RPCRequest *req, const Domain::SP &domain, + const TransLogServer::Session::SP &session, SerialNum syncTo) : FNET_Task(supervisor->GetScheduler()), _req(*req), _domain(domain), @@ -50,9 +47,7 @@ SyncHandler::SyncHandler(FRT_Supervisor *supervisor, } -SyncHandler::~SyncHandler() -{ -} +SyncHandler::~SyncHandler() = default; void @@ -154,14 +149,16 @@ TransLogServer::~TransLogServer() _supervisor->ShutDown(true); } -bool TransLogServer::onStop() +bool +TransLogServer::onStop() { LOG(info, "Stopping TLS"); _reqQ.push(NULL); return true; } -void TransLogServer::run() +void +TransLogServer::run() { FRT_RPCRequest *req(NULL); bool hasPacket(false); @@ -236,7 +233,8 @@ TransLogServer::findDomain(stringref domainName) return domain; } -void TransLogServer::exportRPC(FRT_Supervisor & supervisor) +void +TransLogServer::exportRPC(FRT_Supervisor & supervisor) { _supervisor->SetSessionInitHook(FRT_METHOD(TransLogServer::initSession), this); _supervisor->SetSessionFiniHook(FRT_METHOD(TransLogServer::finiSession), this); @@ -244,32 +242,32 @@ void TransLogServer::exportRPC(FRT_Supervisor & supervisor) FRT_ReflectionBuilder rb( & supervisor); //-- Create Domain ----------------------------------------------------------- - rb.DefineMethod("createDomain", "s", "i", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("createDomain", "s", "i", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("Create a new domain."); rb.ParamDesc("name", "The name of the domain."); rb.ReturnDesc("handle", "A handle(int) to the domain. Negative number indicates error."); //-- Delete Domain ----------------------------------------------------------- - rb.DefineMethod("deleteDomain", "s", "is", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("deleteDomain", "s", "is", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("Create a new domain."); rb.ParamDesc("name", "The name of the domain."); rb.ReturnDesc("retval", "0 on success. Negative number indicates error."); rb.ReturnDesc("errormsg", "Message describing the error, if any."); //-- Open Domain ----------------------------------------------------------- - rb.DefineMethod("openDomain", "s", "i", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("openDomain", "s", "i", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("Open an existing domain."); rb.ParamDesc("name", "The name of the domain."); rb.ReturnDesc("handle", "A handle(int) to the domain. Negative number indicates error."); //-- List Domains ----------------------------------------------------------- - rb.DefineMethod("listDomains", "", "is", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("listDomains", "", "is", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("Will return a list of all the domains."); rb.ReturnDesc("result", "A resultcode(int) of the operation. Negative number indicates error."); rb.ReturnDesc("domains", "List of all the domains in a newline separated string"); //-- Domain Status ----------------------------------------------------------- - rb.DefineMethod("domainStatus", "s", "illl", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("domainStatus", "s", "illl", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("This will return key status information about the domain."); rb.ParamDesc("name", "The name of the domain."); rb.ReturnDesc("result", "A resultcode(int) of the operation. Negative number indicates error."); @@ -278,7 +276,7 @@ void TransLogServer::exportRPC(FRT_Supervisor & supervisor) rb.ReturnDesc("size", "Number of elements in the log."); //-- Domain Commit ----------------------------------------------------------- - rb.DefineMethod("domainCommit", "sx", "is", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("domainCommit", "sx", "is", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("Will commit the data to the log."); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("packet", "The data to commit to the domain."); @@ -286,14 +284,14 @@ void TransLogServer::exportRPC(FRT_Supervisor & supervisor) rb.ReturnDesc("message", "A textual description of the result code."); //-- Domain Prune ----------------------------------------------------------- - rb.DefineMethod("domainPrune", "sl", "i", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("domainPrune", "sl", "i", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("Will erase all operations prior to the serial number."); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("to", "Will erase all up and including."); rb.ReturnDesc("result", "A resultcode(int) of the operation. Negative number indicates error."); //-- Domain Visit ----------------------------------------------------------- - rb.DefineMethod("domainVisit", "sll", "i", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("domainVisit", "sll", "i", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("This will create a visitor that return all operations in the range."); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("from", "Will return all entries following(not including) <from>."); @@ -301,21 +299,21 @@ void TransLogServer::exportRPC(FRT_Supervisor & supervisor) rb.ReturnDesc("result", "A resultcode(int) of the operation. Negative number indicates error. Positive number is the sessionid"); //-- Domain Session Run ----------------------------------------------------------- - rb.DefineMethod("domainSessionRun", "si", "i", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("domainSessionRun", "si", "i", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("This will start the session thread."); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("sessionid", "The session identifier."); rb.ReturnDesc("result", "A resultcode(int) of the operation. Negative number indicates error."); //-- Domain Session Close ----------------------------------------------------------- - rb.DefineMethod("domainSessionClose", "si", "i", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("domainSessionClose", "si", "i", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("This will close the session."); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("sessionid", "The session identifier."); rb.ReturnDesc("result", "A resultcode(int) of the operation. Negative number indicates error. 1 means busy -> retry. 0 is OK."); //-- Domain Sync -- - rb.DefineMethod("domainSync", "sl", "il", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("domainSync", "sl", "il", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("Sync domain to given entry"); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("syncto", "Entry to sync to"); @@ -325,6 +323,8 @@ void TransLogServer::exportRPC(FRT_Supervisor & supervisor) namespace { +constexpr double NEVER(-1.0); + void writeDomainDir(std::lock_guard<std::mutex> &guard, vespalib::string dir, @@ -344,9 +344,77 @@ writeDomainDir(std::lock_guard<std::mutex> &guard, vespalib::File::sync(dir); } +class RPCDestination : public Session::Destination { +public: + RPCDestination(FRT_Supervisor & supervisor, FNET_Connection * connection) + : _supervisor(supervisor), _connection(connection), _ok(true) + { + _connection->AddRef(); + } + ~RPCDestination() override { _connection->SubRef(); } + + bool ok() const override { + return _ok; + } + + bool send(int32_t id, const vespalib::string & domain, const Packet & packet) override { + FRT_RPCRequest *req = _supervisor.AllocRPCRequest(); + req->SetMethodName("visitCallback"); + req->GetParams()->AddString(domain.c_str()); + req->GetParams()->AddInt32(id); + req->GetParams()->AddData(packet.getHandle().c_str(), packet.getHandle().size()); + return send(req); + } + + bool sendDone(int32_t id, const vespalib::string & domain) override { + FRT_RPCRequest *req = _supervisor.AllocRPCRequest(); + req->SetMethodName("eofCallback"); + req->GetParams()->AddString(domain.c_str()); + req->GetParams()->AddInt32(id); + bool retval(send(req)); + return retval; + } + bool connected() const override { + return (_connection->GetState() <= FNET_Connection::FNET_CONNECTED); + } +private: + bool send(FRT_RPCRequest * req) { + int32_t retval = rpc(req); + if ( ! ((retval == RPC::OK) || (retval == FRTE_RPC_CONNECTION)) ) { + LOG(error, "Return value != OK(%d) in send for method 'visitCallback'.", retval); + } + req->SubRef(); + + return (retval == RPC::OK); + } + int32_t rpc(FRT_RPCRequest * req) { + int32_t retval(-7); + LOG(debug, "rpc %s starting.", req->GetMethodName()); + FRT_Supervisor::InvokeSync(_supervisor.GetTransport(), _connection, req, NEVER); + if (req->GetErrorCode() == FRTE_NO_ERROR) { + retval = (req->GetReturn()->GetValue(0)._intval32); + LOG(debug, "rpc %s = %d\n", req->GetMethodName(), retval); + } else if (req->GetErrorCode() == FRTE_RPC_TIMEOUT) { + LOG(warning, "rpc %s timed out. Will allow to continue: error(%d): %s\n", req->GetMethodName(), req->GetErrorCode(), req->GetErrorMessage()); + retval = -req->GetErrorCode(); + } else { + if (req->GetErrorCode() != FRTE_RPC_CONNECTION) { + LOG(warning, "rpc %s: error(%d): %s\n", req->GetMethodName(), req->GetErrorCode(), req->GetErrorMessage()); + } + retval = -req->GetErrorCode(); + _ok = false; + } + return retval; + } + FRT_Supervisor & _supervisor; + FNET_Connection * _connection; + bool _ok; +}; + } -void TransLogServer::createDomain(FRT_RPCRequest *req) +void +TransLogServer::createDomain(FRT_RPCRequest *req) { uint32_t retval(0); FRT_Values & params = *req->GetParams(); @@ -373,7 +441,8 @@ void TransLogServer::createDomain(FRT_RPCRequest *req) ret.AddInt32(retval); } -void TransLogServer::deleteDomain(FRT_RPCRequest *req) +void +TransLogServer::deleteDomain(FRT_RPCRequest *req) { uint32_t retval(0); vespalib::string msg("ok"); @@ -410,7 +479,8 @@ void TransLogServer::deleteDomain(FRT_RPCRequest *req) ret.AddString(msg.c_str()); } -void TransLogServer::openDomain(FRT_RPCRequest *req) +void +TransLogServer::openDomain(FRT_RPCRequest *req) { uint32_t retval(0); FRT_Values & params = *req->GetParams(); @@ -427,7 +497,8 @@ void TransLogServer::openDomain(FRT_RPCRequest *req) ret.AddInt32(retval); } -void TransLogServer::listDomains(FRT_RPCRequest *req) +void +TransLogServer::listDomains(FRT_RPCRequest *req) { FRT_Values & ret = *req->GetReturn(); LOG(debug, "listDomains()"); @@ -442,7 +513,8 @@ void TransLogServer::listDomains(FRT_RPCRequest *req) ret.AddString(domains.c_str()); } -void TransLogServer::domainStatus(FRT_RPCRequest *req) +void +TransLogServer::domainStatus(FRT_RPCRequest *req) { FRT_Values & params = *req->GetParams(); FRT_Values & ret = *req->GetReturn(); @@ -462,7 +534,8 @@ void TransLogServer::domainStatus(FRT_RPCRequest *req) } } -void TransLogServer::commit(const vespalib::string & domainName, const Packet & packet, DoneCallback done) +void +TransLogServer::commit(const vespalib::string & domainName, const Packet & packet, DoneCallback done) { (void) done; Domain::SP domain(findDomain(domainName)); @@ -473,7 +546,8 @@ void TransLogServer::commit(const vespalib::string & domainName, const Packet & } } -void TransLogServer::domainCommit(FRT_RPCRequest *req) +void +TransLogServer::domainCommit(FRT_RPCRequest *req) { FRT_Values & params = *req->GetParams(); FRT_Values & ret = *req->GetReturn(); @@ -496,7 +570,8 @@ void TransLogServer::domainCommit(FRT_RPCRequest *req) } } -void TransLogServer::domainVisit(FRT_RPCRequest *req) +void +TransLogServer::domainVisit(FRT_RPCRequest *req) { uint32_t retval(uint32_t(-1)); FRT_Values & params = *req->GetParams(); @@ -508,12 +583,13 @@ void TransLogServer::domainVisit(FRT_RPCRequest *req) SerialNum from(params[1]._intval64); SerialNum to(params[2]._intval64); LOG(debug, "domainVisit(%s, %" PRIu64 ", %" PRIu64 ")", domainName, from, to); - retval = domain->visit(domain, from, to, *_supervisor, req->GetConnection()); + retval = domain->visit(domain, from, to, std::make_unique<RPCDestination>(*_supervisor, req->GetConnection())); } ret.AddInt32(retval); } -void TransLogServer::domainSessionRun(FRT_RPCRequest *req) +void +TransLogServer::domainSessionRun(FRT_RPCRequest *req) { uint32_t retval(uint32_t(-1)); FRT_Values & params = *req->GetParams(); @@ -529,13 +605,15 @@ void TransLogServer::domainSessionRun(FRT_RPCRequest *req) ret.AddInt32(retval); } -void TransLogServer::relayToThreadRPC(FRT_RPCRequest *req) +void +TransLogServer::relayToThreadRPC(FRT_RPCRequest *req) { req->Detach(); _reqQ.push(req); } -void TransLogServer::domainSessionClose(FRT_RPCRequest *req) +void +TransLogServer::domainSessionClose(FRT_RPCRequest *req) { uint32_t retval(uint32_t(-1)); FRT_Values & params = *req->GetParams(); @@ -552,7 +630,8 @@ void TransLogServer::domainSessionClose(FRT_RPCRequest *req) ret.AddInt32(retval); } -void TransLogServer::domainPrune(FRT_RPCRequest *req) +void +TransLogServer::domainPrune(FRT_RPCRequest *req) { uint32_t retval(uint32_t(-1)); FRT_Values & params = *req->GetParams(); @@ -572,7 +651,6 @@ void TransLogServer::domainPrune(FRT_RPCRequest *req) ret.AddInt32(retval); } - const TransLogServer::Session::SP & TransLogServer::getSession(FRT_RPCRequest *req) { @@ -582,14 +660,12 @@ TransLogServer::getSession(FRT_RPCRequest *req) return *sessionspp; } - void TransLogServer::initSession(FRT_RPCRequest *req) { req->GetConnection()->SetContext(new Session::SP(new Session())); } - void TransLogServer::finiSession(FRT_RPCRequest *req) { @@ -600,14 +676,12 @@ TransLogServer::finiSession(FRT_RPCRequest *req) delete sessionspp; } - void TransLogServer::downSession(FRT_RPCRequest *req) { getSession(req)->setDown(); } - void TransLogServer::domainSync(FRT_RPCRequest *req) { diff --git a/searchlib/src/vespa/searchlib/transactionlog/translogserver.h b/searchlib/src/vespa/searchlib/transactionlog/translogserver.h index 189be8c38d8..8aedfef6d8d 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/translogserver.h +++ b/searchlib/src/vespa/searchlib/transactionlog/translogserver.h @@ -8,6 +8,9 @@ #include <vespa/fnet/frt/invokable.h> #include <mutex> + +class FRT_Supervisor; + namespace search::common { class FileHeaderContext; } namespace search::transactionlog { diff --git a/slobrok/src/tests/mirrorapi/mirrorapi.cpp b/slobrok/src/tests/mirrorapi/mirrorapi.cpp index 0550bf51b0c..f77dfd80986 100644 --- a/slobrok/src/tests/mirrorapi/mirrorapi.cpp +++ b/slobrok/src/tests/mirrorapi/mirrorapi.cpp @@ -41,7 +41,7 @@ Server::Server(std::string name, int port, std::string slobrokSpec) { FRT_ReflectionBuilder rb(&_orb); //--------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", true, + rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", FRT_METHOD(Server::rpc_listNamesServed), this); rb.MethodDesc("Look up a rpcserver"); rb.ReturnDesc("names", "The rpcserver names on this server"); diff --git a/slobrok/src/tests/oldapi/old.cpp b/slobrok/src/tests/oldapi/old.cpp index 77bca6dfe90..42cec186a08 100644 --- a/slobrok/src/tests/oldapi/old.cpp +++ b/slobrok/src/tests/oldapi/old.cpp @@ -39,7 +39,7 @@ Server::Server(std::string name, int port, std::string slobrokSpec) { FRT_ReflectionBuilder rb(&_orb); //--------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", true, + rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", FRT_METHOD(Server::rpc_listNamesServed), this); rb.MethodDesc("Look up a rpcserver"); rb.ReturnDesc("names", "The rpcserver names on this server"); diff --git a/slobrok/src/tests/standalone/standalone.cpp b/slobrok/src/tests/standalone/standalone.cpp index 63f8b1d2c59..136f8125c8b 100644 --- a/slobrok/src/tests/standalone/standalone.cpp +++ b/slobrok/src/tests/standalone/standalone.cpp @@ -26,7 +26,7 @@ Server::Server(std::string name, int port) { FRT_ReflectionBuilder rb(&_orb); //--------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", true, + rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", FRT_METHOD(Server::rpc_listNamesServed), this); rb.MethodDesc("Look up a rpcserver"); rb.ReturnDesc("names", "The rpcserver names on this server"); diff --git a/slobrok/src/tests/startsome/tstdst.cpp b/slobrok/src/tests/startsome/tstdst.cpp index 44b42e1ff83..4723b3819d7 100644 --- a/slobrok/src/tests/startsome/tstdst.cpp +++ b/slobrok/src/tests/startsome/tstdst.cpp @@ -87,12 +87,12 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) FRT_ReflectionBuilder rb(supervisor); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", true, + rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", FRT_METHOD(RPCHooks::rpc_listNamesServed), this); rb.MethodDesc("Look up a rpcserver"); rb.ReturnDesc("names", "The rpcserver names on this server"); //------------------------------------------------------------------------- - rb.DefineMethod("system.stop", "", "", true, + rb.DefineMethod("system.stop", "", "", FRT_METHOD(RPCHooks::rpc_stop), this); rb.MethodDesc("Shut down the application"); //------------------------------------------------------------------------- diff --git a/slobrok/src/vespa/slobrok/sbregister.cpp b/slobrok/src/vespa/slobrok/sbregister.cpp index 8f8e42a39aa..a1346feeece 100644 --- a/slobrok/src/vespa/slobrok/sbregister.cpp +++ b/slobrok/src/vespa/slobrok/sbregister.cpp @@ -277,12 +277,12 @@ RegisterAPI::RPCHooks::RPCHooks(RegisterAPI &owner) { FRT_ReflectionBuilder rb(&_owner._orb); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", true, + rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", FRT_METHOD(RPCHooks::rpc_listNamesServed), this); rb.MethodDesc("List rpcserver names"); rb.ReturnDesc("names", "The rpcserver names this server wants to serve"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.notifyUnregistered", "s", "", true, + rb.DefineMethod("slobrok.callback.notifyUnregistered", "s", "", FRT_METHOD(RPCHooks::rpc_notifyUnregistered), this); rb.MethodDesc("Notify a server about removed registration"); rb.ParamDesc("name", "RpcServer name"); diff --git a/slobrok/src/vespa/slobrok/server/rpchooks.cpp b/slobrok/src/vespa/slobrok/server/rpchooks.cpp index 33cc10937df..82e30a309a1 100644 --- a/slobrok/src/vespa/slobrok/server/rpchooks.cpp +++ b/slobrok/src/vespa/slobrok/server/rpchooks.cpp @@ -81,39 +81,39 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) FRT_ReflectionBuilder rb(supervisor); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.system.resume", "", "", true, + rb.DefineMethod("slobrok.system.resume", "", "", FRT_METHOD(RPCHooks::rpc_resume), this); rb.MethodDesc("Enable something - currently NOP"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.system.suspend", "", "", true, + rb.DefineMethod("slobrok.system.suspend", "", "", FRT_METHOD(RPCHooks::rpc_suspend), this); rb.MethodDesc("Disable something - currently NOP"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.system.version", "", "s", true, + rb.DefineMethod("slobrok.system.version", "", "s", FRT_METHOD(RPCHooks::rpc_version), this); rb.MethodDesc("Get location broker version"); rb.ReturnDesc("version", "version string"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.system.stop", "", "", true, + rb.DefineMethod("slobrok.system.stop", "", "", FRT_METHOD(RPCHooks::rpc_stop), this); rb.MethodDesc("Shut down the location broker application"); //------------------------------------------------------------------------- //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.internal.listManagedRpcServers", "", "SS", true, + rb.DefineMethod("slobrok.internal.listManagedRpcServers", "", "SS", FRT_METHOD(RPCHooks::rpc_listManagedRpcServers), this); rb.MethodDesc("List all rpcservers managed by this location broker"); rb.ReturnDesc("names", "Managed rpcserver names"); rb.ReturnDesc("specs", "The connection specifications (in same order)"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.internal.lookupManaged", "s", "ss", true, + rb.DefineMethod("slobrok.internal.lookupManaged", "s", "ss", FRT_METHOD(RPCHooks::rpc_lookupManaged), this); rb.MethodDesc("Lookup a specific rpcserver managed by this location broker"); rb.ParamDesc("name", "Name of rpc server"); rb.ReturnDesc("name", "Name of rpc server"); rb.ReturnDesc("spec", "The connection specification"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.internal.wantAdd", "sss", "is", true, + rb.DefineMethod("slobrok.internal.wantAdd", "sss", "is", FRT_METHOD(RPCHooks::rpc_wantAdd), this); rb.MethodDesc("remote location broker wants to add a rpcserver"); rb.ParamDesc("slobrok", "Name of remote location broker"); @@ -122,7 +122,7 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) rb.ReturnDesc("denied", "non-zero if request was denied"); rb.ReturnDesc("reason", "reason for denial"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.internal.doAdd", "sss", "is", true, + rb.DefineMethod("slobrok.internal.doAdd", "sss", "is", FRT_METHOD(RPCHooks::rpc_doAdd), this); rb.MethodDesc("add rpcserver managed by remote location broker"); rb.ParamDesc("slobrok", "Name of remote location broker"); @@ -131,7 +131,7 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) rb.ReturnDesc("denied", "non-zero if request was denied"); rb.ReturnDesc("reason", "reason for denial"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.internal.doRemove", "sss", "is", true, + rb.DefineMethod("slobrok.internal.doRemove", "sss", "is", FRT_METHOD(RPCHooks::rpc_doRemove), this); rb.MethodDesc("remove rpcserver managed by remote location broker"); rb.ParamDesc("slobrok", "Name of remote location broker"); @@ -142,31 +142,31 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) //------------------------------------------------------------------------- //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", true, + rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", FRT_METHOD(RPCHooks::rpc_listNamesServed), this); rb.MethodDesc("List rpcservers served"); rb.ReturnDesc("names", "The rpcserver names this server wants to serve"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.notifyUnregistered", "s", "", true, + rb.DefineMethod("slobrok.callback.notifyUnregistered", "s", "", FRT_METHOD(RPCHooks::rpc_notifyUnregistered), this); rb.MethodDesc("Notify a server about removed registration"); rb.ParamDesc("name", "NamedService name"); //------------------------------------------------------------------------- //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.admin.removePeer", "ss", "", true, + rb.DefineMethod("slobrok.admin.removePeer", "ss", "", FRT_METHOD(RPCHooks::rpc_removePeer), this); rb.MethodDesc("stop syncing with other location broker"); rb.ParamDesc("slobrok", "NamedService name of remote location broker"); rb.ParamDesc("spec", "Connection specification of remote location broker"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.admin.addPeer", "ss", "", true, + rb.DefineMethod("slobrok.admin.addPeer", "ss", "", FRT_METHOD(RPCHooks::rpc_addPeer), this); rb.MethodDesc("sync our information with other location broker"); rb.ParamDesc("slobrok", "NamedService name of remote location broker"); rb.ParamDesc("spec", "Connection specification of remote location broker"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.admin.listAllRpcServers", "", "SSS", true, + rb.DefineMethod("slobrok.admin.listAllRpcServers", "", "SSS", FRT_METHOD(RPCHooks::rpc_listAllRpcServers), this); rb.MethodDesc("List all known rpcservers"); rb.ReturnDesc("names", "NamedService names"); @@ -175,13 +175,13 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) //------------------------------------------------------------------------- //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.unregisterRpcServer", "ss", "", true, + rb.DefineMethod("slobrok.unregisterRpcServer", "ss", "", FRT_METHOD(RPCHooks::rpc_unregisterRpcServer), this); rb.MethodDesc("Unregister a rpcserver"); rb.ParamDesc("name", "NamedService name"); rb.ParamDesc("spec", "The connection specification"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.registerRpcServer", "ss", "", true, + rb.DefineMethod("slobrok.registerRpcServer", "ss", "", FRT_METHOD(RPCHooks::rpc_registerRpcServer), this); rb.MethodDesc("Register a rpcserver"); rb.ParamDesc("name", "NamedService name"); @@ -189,7 +189,7 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) //------------------------------------------------------------------------- //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.mirror.fetch", "ii", "SSi", true, + rb.DefineMethod("slobrok.mirror.fetch", "ii", "SSi", FRT_METHOD(RPCHooks::rpc_mirrorFetch), this); rb.MethodDesc("Fetch or update mirror of name to spec map"); rb.ParamDesc("gencnt", "generation already known by client"); @@ -199,7 +199,7 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) rb.ReturnDesc("specs", "Array of connection specifications (same order)"); rb.ReturnDesc("newgen", "Generation count for new version of the map"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.incremental.fetch", "ii", "iSSSi", true, + rb.DefineMethod("slobrok.incremental.fetch", "ii", "iSSSi", FRT_METHOD(RPCHooks::rpc_incrementalFetch), this); rb.MethodDesc("Fetch or update mirror of name to spec map"); rb.ParamDesc("gencnt", "generation already known by client"); @@ -212,7 +212,7 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) rb.ReturnDesc("specs", "Array of connection specifications (same order)"); rb.ReturnDesc("newgen", "Generation count for new version of the map"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.lookupRpcServer", "s", "SS", true, + rb.DefineMethod("slobrok.lookupRpcServer", "s", "SS", FRT_METHOD(RPCHooks::rpc_lookupRpcServer), this); rb.MethodDesc("Look up rpcservers"); rb.ParamDesc("pattern", "The pattern of the rpcservers to lookup.\n" diff --git a/staging_vespalib/src/tests/util/process_memory_stats/process_memory_stats_test.cpp b/staging_vespalib/src/tests/util/process_memory_stats/process_memory_stats_test.cpp index 45e75547fb2..e4274664336 100644 --- a/staging_vespalib/src/tests/util/process_memory_stats/process_memory_stats_test.cpp +++ b/staging_vespalib/src/tests/util/process_memory_stats/process_memory_stats_test.cpp @@ -80,4 +80,12 @@ TEST("grow mapped memory") munmap(mapAddr, mapLen); } +TEST("order samples") +{ + ProcessMemoryStats a(0,0,0,7,0); + ProcessMemoryStats b(0,0,0,8,0); + EXPECT_TRUE(a < b); + EXPECT_FALSE(b < a); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.cpp b/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.cpp index 138e7a25803..f0cbefd443b 100644 --- a/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.cpp +++ b/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.cpp @@ -4,6 +4,7 @@ #include <vespa/vespalib/stllike/asciistream.h> #include <fstream> #include <sstream> +#include <algorithm> #include <vespa/log/log.h> @@ -182,18 +183,20 @@ ProcessMemoryStats::toString() const ProcessMemoryStats ProcessMemoryStats::create(uint64_t sizeEpsilon) { - ProcessMemoryStats prevStats = createStatsFromSmaps(); - const size_t NUM_TRIES = 10; + constexpr size_t NUM_TRIES = 10; + std::vector<ProcessMemoryStats> samples; + samples.reserve(NUM_TRIES); + samples.push_back(createStatsFromSmaps()); for (size_t i = 0; i < NUM_TRIES; ++i) { - ProcessMemoryStats currStats = createStatsFromSmaps(); - if (prevStats.similarTo(currStats, sizeEpsilon)) { - return prevStats; + samples.push_back(createStatsFromSmaps()); + if (samples.back().similarTo(*(samples.rbegin()+1), sizeEpsilon)) { + return samples.back(); } LOG(info, "create(): Memory stats have changed, trying to read smaps file again: i=%zu, prevStats={%s}, currStats={%s}", - i, prevStats.toString().c_str(), currStats.toString().c_str()); - prevStats = currStats; + i, (samples.rbegin()+1)->toString().c_str(), samples.back().toString().c_str()); } - return prevStats; + std::sort(samples.begin(), samples.end()); + return samples[samples.size()/2]; } } diff --git a/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.h b/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.h index fe5062f75cd..3870a2e2907 100644 --- a/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.h +++ b/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.h @@ -2,7 +2,6 @@ #pragma once -#include <cstdint> #include <vespa/vespalib/stllike/string.h> namespace vespalib { @@ -37,6 +36,7 @@ public: uint64_t getMappingsCount() const { return _mappings_count; } bool similarTo(const ProcessMemoryStats &rhs, uint64_t sizeEpsilon) const; vespalib::string toString() const; + bool operator < (const ProcessMemoryStats & rhs) const { return _anonymous_rss < rhs._anonymous_rss; } /** for unit tests only */ ProcessMemoryStats(uint64_t, uint64_t, uint64_t, uint64_t, uint64_t); diff --git a/standalone-container/src/main/java/com/yahoo/container/standalone/CloudConfigInstallVariables.java b/standalone-container/src/main/java/com/yahoo/container/standalone/CloudConfigInstallVariables.java index be91eceee2f..0be4a55275c 100644 --- a/standalone-container/src/main/java/com/yahoo/container/standalone/CloudConfigInstallVariables.java +++ b/standalone-container/src/main/java/com/yahoo/container/standalone/CloudConfigInstallVariables.java @@ -56,7 +56,7 @@ public class CloudConfigInstallVariables implements CloudConfigOptions { @Override public Optional<Integer> zookeeperQuorumPort() { - return getInstallVariable("zookeeper_quoromPort", Integer::parseInt); + return getInstallVariable("zookeeper_quorumPort", Integer::parseInt); } @Override diff --git a/standalone-container/src/main/java/com/yahoo/container/standalone/StandaloneContainerApplication.java b/standalone-container/src/main/java/com/yahoo/container/standalone/StandaloneContainerApplication.java index d88245587d2..8d2cd429517 100644 --- a/standalone-container/src/main/java/com/yahoo/container/standalone/StandaloneContainerApplication.java +++ b/standalone-container/src/main/java/com/yahoo/container/standalone/StandaloneContainerApplication.java @@ -20,6 +20,7 @@ import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.config.model.application.provider.StaticConfigDefinitionRepo; import com.yahoo.config.model.builder.xml.ConfigModelId; import com.yahoo.config.model.builder.xml.XmlHelper; +import com.yahoo.config.model.deploy.DeployProperties; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.provision.Zone; import com.yahoo.container.di.config.SubscriberFactory; @@ -209,17 +210,22 @@ public class StandaloneContainerApplication implements Application { } private static ContainerModelBuilder newContainerModelBuilder(Networking networkingOption) { + return isConfigServer() ? + new ConfigServerContainerModelBuilder(new CloudConfigInstallVariables()) : + new ContainerModelBuilder(true, networkingOption); + } + + private static boolean isConfigServer() { Optional<String> profile = optionalInstallVariable(DEPLOYMENT_PROFILE_INSTALL_VARIABLE); if (profile.isPresent()) { String profileName = profile.get(); - if ("configserver".equals(profileName)) { - return new ConfigServerContainerModelBuilder(new CloudConfigInstallVariables()); - } else { + if (profileName.equals("configserver")) + return true; + else throw new RuntimeException("Invalid deployment profile '" + profileName + "'"); - } - } else { - return new ContainerModelBuilder(true, networkingOption); } + + return false; } static Pair<VespaModel, Container> createContainerModel(Path applicationPath, FileRegistry fileRegistry, @@ -229,8 +235,7 @@ public class StandaloneContainerApplication implements Application { .includeSourceFiles(true).preprocessedDir(preprocessedApplicationDir).build(); ApplicationPackage applicationPackage = rawApplicationPackage.preprocess(Zone.defaultZone(), logger); validateApplication(applicationPackage); - DeployState deployState = new DeployState.Builder().applicationPackage(applicationPackage).fileRegistry(fileRegistry) - .deployLogger(logger).configDefinitionRepo(configDefinitionRepo).build(); + DeployState deployState = createDeployState(applicationPackage, fileRegistry, logger); VespaModel root = VespaModel.createIncomplete(deployState); ApplicationConfigProducerRoot vespaRoot = new ApplicationConfigProducerRoot(root, "vespa", deployState.getDocumentModel(), @@ -252,6 +257,23 @@ public class StandaloneContainerApplication implements Application { return new Pair<>(root, container); } + private static DeployState createDeployState(ApplicationPackage applicationPackage, FileRegistry fileRegistry, DeployLogger logger) { + DeployState.Builder builder = new DeployState.Builder() + .applicationPackage(applicationPackage) + .fileRegistry(fileRegistry) + .deployLogger(logger) + .configDefinitionRepo(configDefinitionRepo); + + /* Temporarily disable until we know how status.html is updated for config servers/controllers + if (isConfigServer()) + builder.properties(new DeployProperties.Builder() + .hostedVespa(new CloudConfigInstallVariables().hostedVespa().orElse(Boolean.FALSE)) + .build()); + */ + + return builder.build(); + } + private static void initializeContainer(Container container, Element spec) { HostResource host = container.getRoot().getHostSystem().getHost(Container.SINGLENODE_CONTAINER_SERVICESPEC); diff --git a/storage/src/vespa/storage/storageserver/fnetlistener.cpp b/storage/src/vespa/storage/storageserver/fnetlistener.cpp index 1a72190b2a6..e31bded772c 100644 --- a/storage/src/vespa/storage/storageserver/fnetlistener.cpp +++ b/storage/src/vespa/storage/storageserver/fnetlistener.cpp @@ -65,7 +65,7 @@ FNetListener::initRPC() { FRT_ReflectionBuilder rb(_orb.get()); - rb.DefineMethod("getnodestate3", "sii", "ss", true, FRT_METHOD(FNetListener::RPC_getNodeState2), this); + rb.DefineMethod("getnodestate3", "sii", "ss", FRT_METHOD(FNetListener::RPC_getNodeState2), this); rb.MethodDesc("Get state of this node"); rb.ParamDesc("nodestate", "Expected state of given node. If correct, the " "request will be queued on target until it changes. To not give " @@ -74,7 +74,7 @@ FNetListener::initRPC() rb.ReturnDesc("nodestate", "State string for this node"); rb.ReturnDesc("hostinfo", "Information about host this node is running on"); //------------------------------------------------------------------------- - rb.DefineMethod("getnodestate2", "si", "s", true, FRT_METHOD(FNetListener::RPC_getNodeState2), this); + rb.DefineMethod("getnodestate2", "si", "s", FRT_METHOD(FNetListener::RPC_getNodeState2), this); rb.MethodDesc("Get state of this node"); rb.ParamDesc("nodestate", "Expected state of given node. If correct, the " "request will be queued on target until it changes. To not give " @@ -82,17 +82,17 @@ FNetListener::initRPC() rb.ParamDesc("timeout", "Timeout of message in milliseconds, set by the state requester"); rb.ReturnDesc("nodestate", "State string for this node"); //------------------------------------------------------------------------- - rb.DefineMethod("setsystemstate2", "s", "", true, FRT_METHOD(FNetListener::RPC_setSystemState2), this); + rb.DefineMethod("setsystemstate2", "s", "", FRT_METHOD(FNetListener::RPC_setSystemState2), this); rb.MethodDesc("Set systemstate on this node"); rb.ParamDesc("systemstate", "New systemstate to set"); //------------------------------------------------------------------------- - rb.DefineMethod("setdistributionstates", "bix", "", true, FRT_METHOD(FNetListener::RPC_setDistributionStates), this); + rb.DefineMethod("setdistributionstates", "bix", "", FRT_METHOD(FNetListener::RPC_setDistributionStates), this); rb.MethodDesc("Set distribution states for cluster and bucket spaces"); rb.ParamDesc("compressionType", "Compression type for payload"); rb.ParamDesc("uncompressedSize", "Uncompressed size for payload"); rb.ParamDesc("payload", "Binary Slime format payload"); //------------------------------------------------------------------------- - rb.DefineMethod("getcurrenttime", "", "lis", true, FRT_METHOD(FNetListener::RPC_getCurrentTime), this); + rb.DefineMethod("getcurrenttime", "", "lis", FRT_METHOD(FNetListener::RPC_getCurrentTime), this); rb.MethodDesc("Get current time on this node"); rb.ReturnDesc("seconds", "Current time in seconds since epoch"); rb.ReturnDesc("nanoseconds", "additional nanoseconds since epoch"); diff --git a/travis/travis-build-full.sh b/travis/travis-build-full.sh index 53e174b534b..6c11502c355 100755 --- a/travis/travis-build-full.sh +++ b/travis/travis-build-full.sh @@ -14,7 +14,7 @@ ccache --print-config cd ${SOURCE_DIR} sh ./bootstrap.sh java -mvn install --no-snapshot-updates --batch-mode --threads ${NUM_THREADS} +mvn -V install --no-snapshot-updates --batch-mode --threads ${NUM_THREADS} bash ${SOURCE_DIR}/bootstrap-cmake.sh ${SOURCE_DIR} make -j ${NUM_THREADS} ctest3 --output-on-failure -j ${NUM_THREADS} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateDeserializer.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateDeserializer.java index 5dd6ceb16b4..59f10a78a58 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateDeserializer.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateDeserializer.java @@ -4,7 +4,7 @@ package com.yahoo.vespa.athenz.client.zts.bindings.serializers; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonDeserializer; -import com.yahoo.vespa.athenz.tls.X509CertificateUtils; +import com.yahoo.security.X509CertificateUtils; import java.io.IOException; import java.security.cert.X509Certificate; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateListDeserializer.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateListDeserializer.java index c496031c116..64b23af9295 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateListDeserializer.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateListDeserializer.java @@ -4,7 +4,7 @@ package com.yahoo.vespa.athenz.client.zts.bindings.serializers; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonDeserializer; -import com.yahoo.vespa.athenz.tls.X509CertificateUtils; +import com.yahoo.security.X509CertificateUtils; import java.io.IOException; import java.security.cert.X509Certificate; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identity/SiaIdentityProvider.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identity/SiaIdentityProvider.java index b06ae089b2a..d8fa910aa73 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identity/SiaIdentityProvider.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identity/SiaIdentityProvider.java @@ -5,8 +5,8 @@ import com.google.inject.Inject; import com.yahoo.component.AbstractComponent; import com.yahoo.log.LogLevel; import com.yahoo.vespa.athenz.api.AthenzService; -import com.yahoo.vespa.athenz.tls.KeyStoreType; -import com.yahoo.vespa.athenz.tls.SslContextBuilder; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.SslContextBuilder; import com.yahoo.vespa.athenz.utils.SiaUtils; import javax.net.ssl.SSLContext; @@ -92,8 +92,8 @@ public class SiaIdentityProvider extends AbstractComponent implements ServiceIde private SSLContext createIdentitySslContext() { return new SslContextBuilder() - .withTrustStore(trustStoreFile, KeyStoreType.JKS) - .withKeyStore(privateKeyFile, certificateFile) + .withTrustStore(trustStoreFile.toPath(), KeyStoreType.JKS) + .withKeyStore(privateKeyFile.toPath(), certificateFile.toPath()) .build(); } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java index 5567831d49d..4a189c872bc 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java @@ -11,10 +11,10 @@ import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper; import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocumentClient; import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; import com.yahoo.vespa.athenz.tls.AthenzIdentityVerifier; -import com.yahoo.vespa.athenz.tls.KeyAlgorithm; -import com.yahoo.vespa.athenz.tls.KeyUtils; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.SslContextBuilder; import com.yahoo.vespa.athenz.tls.Pkcs10Csr; -import com.yahoo.vespa.athenz.tls.SslContextBuilder; import com.yahoo.vespa.athenz.utils.SiaUtils; import com.yahoo.vespa.defaults.Defaults; @@ -31,7 +31,7 @@ import java.time.Clock; import java.time.Duration; import java.util.Optional; -import static com.yahoo.vespa.athenz.tls.KeyStoreType.JKS; +import static com.yahoo.security.KeyStoreType.JKS; import static java.util.Collections.singleton; /** @@ -153,7 +153,7 @@ class AthenzCredentialsService { private SSLContext createIdentitySslContext(PrivateKey privateKey, X509Certificate certificate) { return new SslContextBuilder() .withKeyStore(privateKey, certificate) - .withTrustStore(trustStoreJks, JKS) + .withTrustStore(trustStoreJks.toPath(), JKS) .build(); } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java index 266e2ebcefd..e318ebeb7fd 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java @@ -19,8 +19,8 @@ import com.yahoo.vespa.athenz.client.zts.DefaultZtsClient; import com.yahoo.vespa.athenz.client.zts.ZtsClient; import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider; import com.yahoo.vespa.athenz.identity.SiaIdentityProvider; -import com.yahoo.vespa.athenz.tls.KeyStoreType; -import com.yahoo.vespa.athenz.tls.SslContextBuilder; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.SslContextBuilder; import com.yahoo.vespa.athenz.utils.SiaUtils; import com.yahoo.vespa.defaults.Defaults; @@ -177,7 +177,7 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen X509Certificate roleCertificate = client.getRoleCertificate(role, credentials.getKeyPair(), dnsSuffix); return new SslContextBuilder() .withKeyStore(credentials.getKeyPair().getPrivate(), roleCertificate) - .withTrustStore(getDefaultTrustStoreLocation(), KeyStoreType.JKS) + .withTrustStore(getDefaultTrustStoreLocation().toPath(), KeyStoreType.JKS) .build(); } } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/AthenzX509CertificateUtils.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/AthenzX509CertificateUtils.java index 46aca707be1..33e5552eaf6 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/AthenzX509CertificateUtils.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/AthenzX509CertificateUtils.java @@ -8,7 +8,7 @@ import com.yahoo.vespa.athenz.utils.AthenzIdentities; import java.security.cert.X509Certificate; import java.util.List; -import static com.yahoo.vespa.athenz.tls.SubjectAlternativeName.Type.RFC822_NAME; +import static com.yahoo.security.SubjectAlternativeName.Type.RFC822_NAME; /** * Utility methods for Athenz issued x509 certificates @@ -23,26 +23,26 @@ public class AthenzX509CertificateUtils { public static boolean isAthenzRoleCertificate(X509Certificate certificate) { return isAthenzIssuedCertificate(certificate) && - X509CertificateUtils.getSubjectCommonNames(certificate).get(0).contains(COMMON_NAME_ROLE_DELIMITER); + com.yahoo.security.X509CertificateUtils.getSubjectCommonNames(certificate).get(0).contains(COMMON_NAME_ROLE_DELIMITER); } public static boolean isAthenzIssuedCertificate(X509Certificate certificate) { - return X509CertificateUtils.getIssuerCommonNames(certificate).stream() + return com.yahoo.security.X509CertificateUtils.getIssuerCommonNames(certificate).stream() .anyMatch(cn -> cn.equalsIgnoreCase("Yahoo Athenz CA") || cn.equalsIgnoreCase("Athenz AWS CA")); } public static AthenzIdentity getIdentityFromRoleCertificate(X509Certificate certificate) { - List<SubjectAlternativeName> sans = X509CertificateUtils.getSubjectAlternativeNames(certificate); + List<com.yahoo.security.SubjectAlternativeName> sans = com.yahoo.security.X509CertificateUtils.getSubjectAlternativeNames(certificate); return sans.stream() .filter(san -> san.getType() == RFC822_NAME) - .map(SubjectAlternativeName::getValue) + .map(com.yahoo.security.SubjectAlternativeName::getValue) .map(AthenzX509CertificateUtils::getIdentityFromSanEmail) .findFirst() .orElseThrow(() -> new IllegalArgumentException("Could not find identity in SAN: " + sans)); } public static AthenzRole getRolesFromRoleCertificate(X509Certificate certificate) { - String commonName = X509CertificateUtils.getSubjectCommonNames(certificate).get(0); + String commonName = com.yahoo.security.X509CertificateUtils.getSubjectCommonNames(certificate).get(0); int delimiterIndex = commonName.indexOf(COMMON_NAME_ROLE_DELIMITER); String domain = commonName.substring(0, delimiterIndex); String roleName = commonName.substring(delimiterIndex + COMMON_NAME_ROLE_DELIMITER.length()); diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10Csr.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10Csr.java index e0029681b23..8138be9d7d8 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10Csr.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10Csr.java @@ -19,7 +19,9 @@ import static java.util.stream.Collectors.toList; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public class Pkcs10Csr { private final PKCS10CertificationRequest csr; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrBuilder.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrBuilder.java index 2135f569aeb..702b2f6cd4b 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrBuilder.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrBuilder.java @@ -24,7 +24,9 @@ import static com.yahoo.vespa.athenz.tls.SubjectAlternativeName.Type.DNS_NAME; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public class Pkcs10CsrBuilder { private final X500Principal subject; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrUtils.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrUtils.java index 2289c9ac0ee..be7bb3690bd 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrUtils.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrUtils.java @@ -13,7 +13,9 @@ import java.io.UncheckedIOException; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public class Pkcs10CsrUtils { private Pkcs10CsrUtils() {} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SignatureAlgorithm.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SignatureAlgorithm.java index 2f3e2721751..1ff8ebbe78a 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SignatureAlgorithm.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SignatureAlgorithm.java @@ -3,7 +3,9 @@ package com.yahoo.vespa.athenz.tls; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public enum SignatureAlgorithm { SHA256_WITH_RSA("SHA256withRSA"), SHA512_WITH_RSA("SHA512withRSA"); diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SubjectAlternativeName.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SubjectAlternativeName.java index 8b89fc6fe7f..f5b0c7aa1c6 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SubjectAlternativeName.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SubjectAlternativeName.java @@ -15,7 +15,9 @@ import static java.util.stream.Collectors.toList; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public class SubjectAlternativeName { private final Type type; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/AthenzIdentities.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/AthenzIdentities.java index 82aecc62306..5e01d0cddfc 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/AthenzIdentities.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/AthenzIdentities.java @@ -1,11 +1,11 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.athenz.utils; +import com.yahoo.security.X509CertificateUtils; import com.yahoo.vespa.athenz.api.AthenzDomain; import com.yahoo.vespa.athenz.api.AthenzIdentity; import com.yahoo.vespa.athenz.api.AthenzService; import com.yahoo.vespa.athenz.api.AthenzUser; -import com.yahoo.vespa.athenz.tls.X509CertificateUtils; import java.security.cert.X509Certificate; import java.util.List; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/SiaUtils.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/SiaUtils.java index 05459e5488b..98d9061be02 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/SiaUtils.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/SiaUtils.java @@ -2,8 +2,8 @@ package com.yahoo.vespa.athenz.utils; import com.yahoo.vespa.athenz.api.AthenzService; -import com.yahoo.vespa.athenz.tls.KeyUtils; -import com.yahoo.vespa.athenz.tls.X509CertificateUtils; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.X509CertificateUtils; import java.io.IOException; import java.io.UncheckedIOException; diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identity/SiaIdentityProviderTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identity/SiaIdentityProviderTest.java index 7b93ffb035d..6217d6fb2ee 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identity/SiaIdentityProviderTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identity/SiaIdentityProviderTest.java @@ -1,15 +1,14 @@ package com.yahoo.vespa.athenz.identity; -import com.google.common.io.Files; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.KeyStoreUtils; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.SignatureAlgorithm; +import com.yahoo.security.X509CertificateBuilder; +import com.yahoo.security.X509CertificateUtils; import com.yahoo.vespa.athenz.api.AthenzService; -import com.yahoo.vespa.athenz.tls.KeyAlgorithm; -import com.yahoo.vespa.athenz.tls.KeyStoreBuilder; -import com.yahoo.vespa.athenz.tls.KeyStoreType; -import com.yahoo.vespa.athenz.tls.KeyStoreUtils; -import com.yahoo.vespa.athenz.tls.KeyUtils; -import com.yahoo.vespa.athenz.tls.SignatureAlgorithm; -import com.yahoo.vespa.athenz.tls.X509CertificateBuilder; -import com.yahoo.vespa.athenz.tls.X509CertificateUtils; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -17,7 +16,8 @@ import org.junit.rules.TemporaryFolder; import javax.security.auth.x500.X500Principal; import java.io.File; import java.io.IOException; -import java.nio.charset.StandardCharsets; +import java.math.BigInteger; +import java.nio.file.Files; import java.security.KeyPair; import java.security.KeyStore; import java.security.cert.X509Certificate; @@ -62,12 +62,12 @@ public class SiaIdentityProviderTest { private void createPrivateKeyFile(File keyFile, KeyPair keypair) throws IOException { String privateKeyPem = KeyUtils.toPem(keypair.getPrivate()); - Files.write(privateKeyPem, keyFile, StandardCharsets.UTF_8); + Files.write(keyFile.toPath(), privateKeyPem.getBytes()); } private void createCertificateFile(X509Certificate certificate, File certificateFile) throws IOException { String certificatePem = X509CertificateUtils.toPem(certificate); - Files.write(certificatePem, certificateFile, StandardCharsets.UTF_8); + Files.write(certificateFile.toPath(), certificatePem.getBytes()); } private X509Certificate createCertificate(KeyPair keypair) { @@ -79,7 +79,7 @@ public class SiaIdentityProviderTest { now, now.plus(Duration.ofDays(1)), SignatureAlgorithm.SHA256_WITH_RSA, - 1) + BigInteger.ONE) .build(); } @@ -87,7 +87,7 @@ public class SiaIdentityProviderTest { KeyStore keystore = KeyStoreBuilder.withType(KeyStoreType.JKS) .withCertificateEntry("dummy-cert", certificate) .build(); - KeyStoreUtils.writeKeyStoreToFile(keystore, trustStoreFile); + KeyStoreUtils.writeKeyStoreToFile(keystore, trustStoreFile.toPath()); } }
\ No newline at end of file diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java index 38483bdbaee..4ad58a766e8 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java @@ -1,12 +1,12 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.athenz.identityprovider.client; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; import com.yahoo.vespa.athenz.api.AthenzService; import com.yahoo.vespa.athenz.identityprovider.api.IdentityType; import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId; -import com.yahoo.vespa.athenz.tls.KeyAlgorithm; -import com.yahoo.vespa.athenz.tls.KeyUtils; import org.junit.Test; import java.security.KeyPair; diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/KeyUtilsTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/KeyUtilsTest.java deleted file mode 100644 index fbdc6f1e3bd..00000000000 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/KeyUtilsTest.java +++ /dev/null @@ -1,36 +0,0 @@ -package com.yahoo.vespa.athenz.tls; - -import org.junit.Test; - -import java.security.KeyPair; -import java.security.PrivateKey; -import java.security.PublicKey; - -import static org.hamcrest.CoreMatchers.containsString; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThat; - -/** - * @author bjorncs - */ -public class KeyUtilsTest { - - @Test - public void can_extract_public_key_from_private() { - KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); - PublicKey publicKey = KeyUtils.extractPublicKey(keyPair.getPrivate()); - assertNotNull(publicKey); - } - - @Test - public void can_serialize_deserialize_pem() { - KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); - String pem = KeyUtils.toPem(keyPair.getPrivate()); - assertThat(pem, containsString("BEGIN RSA PRIVATE KEY")); - assertThat(pem, containsString("END RSA PRIVATE KEY")); - PrivateKey deserializedKey = KeyUtils.fromPemEncodedPrivateKey(pem); - assertEquals(keyPair.getPrivate(), deserializedKey); - } - -}
\ No newline at end of file diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrBuilderTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrBuilderTest.java index e3aaba66efe..3a00ad6a7a4 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrBuilderTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrBuilderTest.java @@ -1,5 +1,7 @@ package com.yahoo.vespa.athenz.tls; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; import org.junit.Test; import javax.security.auth.x500.X500Principal; diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrTest.java index ea60511f39c..8213856512d 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrTest.java @@ -1,5 +1,8 @@ package com.yahoo.vespa.athenz.tls; +import com.yahoo.security.Extension; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; import org.junit.Test; import javax.security.auth.x500.X500Principal; @@ -48,7 +51,7 @@ public class Pkcs10CsrTest { .addSubjectAlternativeName("san") .setBasicConstraints(true, true) .build(); - List<String> expected = Arrays.asList(Extension.BASIC_CONSTRAINS.getOId(), Extension.SUBJECT_ALTERNATIVE_NAMES.getOId()); + List<String> expected = Arrays.asList(Extension.BASIC_CONSTRAINTS.getOId(), Extension.SUBJECT_ALTERNATIVE_NAMES.getOId()); List<String> actual = csr.getExtensionOIds(); assertEquals(expected, actual); } diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrUtilsTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrUtilsTest.java index 5b5a57f1fcc..fcbc6d00a8e 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrUtilsTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrUtilsTest.java @@ -1,5 +1,7 @@ package com.yahoo.vespa.athenz.tls; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; import org.junit.Test; import javax.security.auth.x500.X500Principal; diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/TestUtils.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/TestUtils.java index 2a9b54f9e9e..048538c1a33 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/TestUtils.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/TestUtils.java @@ -1,15 +1,21 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.athenz.tls; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.X509CertificateBuilder; + import javax.security.auth.x500.X500Principal; -import java.io.File; +import java.math.BigInteger; import java.security.KeyPair; import java.security.KeyStore; import java.security.cert.X509Certificate; import java.time.Instant; import java.time.temporal.ChronoUnit; -import static com.yahoo.vespa.athenz.tls.KeyStoreUtils.writeKeyStoreToFile; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_RSA; /** * @author bjorncs @@ -30,11 +36,8 @@ class TestUtils { static X509Certificate createCertificate(KeyPair keyPair, X500Principal subject) { return X509CertificateBuilder .fromKeypair( - keyPair, subject, Instant.now(), Instant.now().plus(1, ChronoUnit.DAYS), SignatureAlgorithm.SHA256_WITH_RSA, 1) + keyPair, subject, Instant.now(), Instant.now().plus(1, ChronoUnit.DAYS), SHA256_WITH_RSA, BigInteger.ONE) .build(); } - static void createKeystoreFile(File file, KeyStoreType type, char[] password) { - writeKeyStoreToFile(createKeystore(type, password), file, password); - } } diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/X509CertificateBuilderTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/X509CertificateBuilderTest.java deleted file mode 100644 index 81ff4fdb208..00000000000 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/X509CertificateBuilderTest.java +++ /dev/null @@ -1,58 +0,0 @@ -package com.yahoo.vespa.athenz.tls; - -import org.junit.Test; - -import javax.security.auth.x500.X500Principal; -import java.security.KeyPair; -import java.security.NoSuchAlgorithmException; -import java.security.cert.X509Certificate; -import java.time.Instant; -import java.time.temporal.ChronoUnit; - -import static org.junit.Assert.assertEquals; - -/** - * @author bjorncs - */ -public class X509CertificateBuilderTest { - - @Test - public void can_build_self_signed_certificate() throws NoSuchAlgorithmException { - KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA, 2048); - X500Principal subject = new X500Principal("CN=myservice"); - X509Certificate cert = - X509CertificateBuilder.fromKeypair( - keyPair, - subject, - Instant.now(), - Instant.now().plus(1, ChronoUnit.DAYS), - SignatureAlgorithm.SHA256_WITH_RSA, - 1) - .setBasicConstraints(true, true) - .build(); - assertEquals(subject, cert.getSubjectX500Principal()); - } - - @Test - public void can_build_certificate_from_csr() { - X500Principal subject = new X500Principal("CN=subject"); - X500Principal issuer = new X500Principal("CN=issuer"); - KeyPair csrKeypair = KeyUtils.generateKeypair(KeyAlgorithm.RSA, 2048); - Pkcs10Csr csr = Pkcs10CsrBuilder.fromKeypair(subject, csrKeypair, SignatureAlgorithm.SHA256_WITH_RSA).build(); - KeyPair caKeypair = KeyUtils.generateKeypair(KeyAlgorithm.RSA, 2048); - X509Certificate cert = X509CertificateBuilder - .fromCsr( - csr, - issuer, - Instant.now(), - Instant.now().plus(1, ChronoUnit.DAYS), - caKeypair.getPrivate(), - SignatureAlgorithm.SHA256_WITH_RSA, - 1) - .addSubjectAlternativeName("subject1.alt") - .addSubjectAlternativeName("subject2.alt") - .build(); - assertEquals(subject, cert.getSubjectX500Principal()); - } - -}
\ No newline at end of file diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentityVerifierTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentityVerifierTest.java index 73382d267be..679476abe12 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentityVerifierTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentityVerifierTest.java @@ -1,24 +1,25 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.athenz.utils; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.X509CertificateBuilder; import com.yahoo.vespa.athenz.api.AthenzIdentity; import com.yahoo.vespa.athenz.api.AthenzService; import com.yahoo.vespa.athenz.tls.AthenzIdentityVerifier; -import com.yahoo.vespa.athenz.tls.X509CertificateBuilder; import org.junit.Test; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; import javax.security.auth.x500.X500Principal; +import java.math.BigInteger; import java.security.KeyPair; -import java.security.KeyPairGenerator; -import java.security.NoSuchAlgorithmException; import java.security.cert.Certificate; import java.security.cert.X509Certificate; import java.time.Duration; import java.time.Instant; -import static com.yahoo.vespa.athenz.tls.SignatureAlgorithm.SHA256_WITH_RSA; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_ECDSA; import static java.util.Collections.singleton; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -34,23 +35,17 @@ public class AthenzIdentityVerifierTest { public void verifies_certificate_with_athenz_service_as_common_name() throws Exception { AthenzIdentity trustedIdentity = new AthenzService("mydomain", "alice"); AthenzIdentity unknownIdentity = new AthenzService("mydomain", "mallory"); - KeyPair keyPair = createKeyPair(); + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC); AthenzIdentityVerifier verifier = new AthenzIdentityVerifier(singleton(trustedIdentity)); assertTrue(verifier.verify("hostname", createSslSessionMock(createSelfSignedCertificate(keyPair, trustedIdentity)))); assertFalse(verifier.verify("hostname", createSslSessionMock(createSelfSignedCertificate(keyPair, unknownIdentity)))); } - private static KeyPair createKeyPair() throws NoSuchAlgorithmException { - KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA"); - keyGen.initialize(512); - return keyGen.generateKeyPair(); - } - private static X509Certificate createSelfSignedCertificate(KeyPair keyPair, AthenzIdentity identity) { X500Principal x500Name = new X500Principal("CN="+ identity.getFullName()); Instant now = Instant.now(); return X509CertificateBuilder - .fromKeypair(keyPair, x500Name, now, now.plus(Duration.ofDays(30)), SHA256_WITH_RSA, 1) + .fromKeypair(keyPair, x500Name, now, now.plus(Duration.ofDays(30)), SHA256_WITH_ECDSA, BigInteger.ONE) .setBasicConstraints(true, true) .build(); } diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/ntoken/NTokenValidatorTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/ntoken/NTokenValidatorTest.java index 22f97ca8b60..750968a437e 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/ntoken/NTokenValidatorTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/ntoken/NTokenValidatorTest.java @@ -6,8 +6,8 @@ import com.yahoo.vespa.athenz.api.AthenzIdentity; import com.yahoo.vespa.athenz.api.AthenzPrincipal; import com.yahoo.vespa.athenz.api.AthenzUser; import com.yahoo.vespa.athenz.api.NToken; -import com.yahoo.vespa.athenz.tls.KeyAlgorithm; -import com.yahoo.vespa.athenz.tls.KeyUtils; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; import com.yahoo.vespa.athenz.utils.ntoken.NTokenValidator.InvalidTokenException; import org.junit.Rule; import org.junit.Test; diff --git a/vespa-documentgen-plugin/etc/complex/book.sd b/vespa-documentgen-plugin/etc/complex/book.sd index 2635ebe9881..16bf4447979 100644 --- a/vespa-documentgen-plugin/etc/complex/book.sd +++ b/vespa-documentgen-plugin/etc/complex/book.sd @@ -30,7 +30,7 @@ search book { attribute: prefetch } - field mynestedwsfloat type weightedset<weightedset<float>> {} + field mynestedwsfloat type weightedset<float> {} field myarrayint type array<int> { indexing: attribute diff --git a/vespa-documentgen-plugin/etc/complex/common.sd b/vespa-documentgen-plugin/etc/complex/common.sd index 0764421ac8d..e0505eba05b 100644 --- a/vespa-documentgen-plugin/etc/complex/common.sd +++ b/vespa-documentgen-plugin/etc/complex/common.sd @@ -17,19 +17,13 @@ search common { indexing: summary } field weight type float { - indexing { - input weight * 10 | attribute | summary; - } + indexing: attribute | summary } field w1 type float { - indexing { - input weight * 6 + input w1 | summary; - } + indexing: summary } field w2 type float { - indexing { - input w2 + input weight | summary; - } + indexing: summary } field did type string { indexing: attribute|index|summary diff --git a/vespa-documentgen-plugin/etc/complex/common2.sd b/vespa-documentgen-plugin/etc/complex/common2.sd new file mode 100644 index 00000000000..e32d3ed6751 --- /dev/null +++ b/vespa-documentgen-plugin/etc/complex/common2.sd @@ -0,0 +1,9 @@ +# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +search common2 { + document { + field com2 type string { + + } + } +} + diff --git a/vespa-documentgen-plugin/etc/complex/music2.sd b/vespa-documentgen-plugin/etc/complex/music2.sd index 5657580e622..2e2d96ecdec 100644 --- a/vespa-documentgen-plugin/etc/complex/music2.sd +++ b/vespa-documentgen-plugin/etc/complex/music2.sd @@ -56,7 +56,7 @@ search music2 { } field didinteger type array<int> { - indexing: input did | split " " | attribute + indexing: input did | split " " | for_each { to_int } | attribute } rank-profile default { diff --git a/vespa-documentgen-plugin/etc/complex/music3.sd b/vespa-documentgen-plugin/etc/complex/music3.sd new file mode 100644 index 00000000000..65f37029d04 --- /dev/null +++ b/vespa-documentgen-plugin/etc/complex/music3.sd @@ -0,0 +1,8 @@ +# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +search music3 { + document music3 inherits music2, common2 { + field mu3 type string { + + } + } +} diff --git a/vespa-documentgen-plugin/etc/localapp/common.sd b/vespa-documentgen-plugin/etc/localapp/common.sd index ada7ce7436a..724897b4e7f 100644 --- a/vespa-documentgen-plugin/etc/localapp/common.sd +++ b/vespa-documentgen-plugin/etc/localapp/common.sd @@ -17,19 +17,13 @@ search common { indexing: summary } field weight type float { - indexing { - input weight * 10 | attribute | summary; - } + indexing: attribute|summary } field w1 type float { - indexing { - input weight * 6 + input w1 | summary; - } + indexing: summary } field w2 type float { - indexing { - input w2 + input weight | summary; - } + indexing: summary } field did type string { indexing: attribute|index|summary diff --git a/vespa-documentgen-plugin/etc/localapp/music.sd b/vespa-documentgen-plugin/etc/localapp/music.sd index 0cfe5cf923a..e00e046f511 100644 --- a/vespa-documentgen-plugin/etc/localapp/music.sd +++ b/vespa-documentgen-plugin/etc/localapp/music.sd @@ -51,7 +51,7 @@ search music { } field didinteger type array<int> { - indexing: input did | split " " | attribute + indexing: input did | split " " | for_each { to_int } | attribute } rank-profile default { diff --git a/vespa-documentgen-plugin/src/main/java/com/yahoo/vespa/DocumentGenMojo.java b/vespa-documentgen-plugin/src/main/java/com/yahoo/vespa/DocumentGenMojo.java index eab3983dc69..acefa3fa461 100644 --- a/vespa-documentgen-plugin/src/main/java/com/yahoo/vespa/DocumentGenMojo.java +++ b/vespa-documentgen-plugin/src/main/java/com/yahoo/vespa/DocumentGenMojo.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa; +import com.yahoo.collections.Pair; import com.yahoo.document.*; import com.yahoo.document.annotation.AnnotationReferenceDataType; import com.yahoo.document.annotation.AnnotationType; @@ -109,7 +110,7 @@ public class DocumentGenMojo extends AbstractMojo { public boolean accept(File dir, String name) { return name.endsWith(".sd"); }}); - SearchBuilder builder = new UnprocessingSearchBuilder(); + SearchBuilder builder = new SearchBuilder(); for (File f : sdFiles) { try { long modTime = f.lastModified(); @@ -405,7 +406,9 @@ public class DocumentGenMojo extends AbstractMojo { */ private void exportDocumentClass(NewDocumentType docType, Writer out, String packageName) throws IOException { String className = className(docType.getName()); - String superType = javaSuperType(docType); + Pair<String, Boolean> extendInfo = javaSuperType(docType); + String superType = extendInfo.getFirst(); + Boolean multiExtends = extendInfo.getSecond(); out.write( "package "+packageName+";\n\n" + exportInnerImportsFromSuperTypes(docType, packageName) + @@ -441,12 +444,14 @@ public class DocumentGenMojo extends AbstractMojo { exportStructTypeGetter(docType.getName()+".header", docType.allHeader().getFields(), out, 1, "getHeaderStructType", "com.yahoo.document.StructDataType"); exportStructTypeGetter(docType.getName()+".body", docType.allBody().getFields(), out, 1, "getBodyStructType", "com.yahoo.document.StructDataType"); - exportExtendedStructTypeGetter(className, docType.getName(), docType.getAllFields(), out, 1, "getDocumentType", "com.yahoo.document.DocumentType"); - exportCopyConstructor(className, docType.getAllFields(), out, 1, true); - exportFieldsAndAccessors(className, "com.yahoo.document.Document".equals(superType) ? docType.getAllFields() : docType.getFields(), out, 1, true); - exportDocumentMethods(docType.getAllFields(), out, 1); - exportHashCode(docType.getAllFields(), out, 1, "(getDataType() != null ? getDataType().hashCode() : 0) + getId().hashCode()"); - exportEquals(className, docType.getAllFields(), out, 1); + Collection<Field> allUniqueFields = getAllUniqueFields(multiExtends, docType.getAllFields()); + exportExtendedStructTypeGetter(className, docType.getName(), allUniqueFields, out, 1, "getDocumentType", "com.yahoo.document.DocumentType"); + exportCopyConstructor(className, allUniqueFields, out, 1, true); + + exportFieldsAndAccessors(className, "com.yahoo.document.Document".equals(superType) ? allUniqueFields : docType.getFields(), out, 1, true); + exportDocumentMethods(allUniqueFields, out, 1); + exportHashCode(allUniqueFields, out, 1, "(getDataType() != null ? getDataType().hashCode() : 0) + getId().hashCode()"); + exportEquals(className, allUniqueFields, out, 1); Set<DataType> exportedStructs = exportStructTypes(docType.getTypes(), out, 1, null); docTypes.put(docType.getName(), packageName+"."+className); for (DataType exportedStruct : exportedStructs) { @@ -455,15 +460,36 @@ public class DocumentGenMojo extends AbstractMojo { out.write("}\n"); } + private Collection<Field> getAllUniqueFields(Boolean multipleInheritance, Collection<Field> allFields) { + if (multipleInheritance) { + Map<String, Field> seen = new HashMap<>(); + List<Field> unique = new ArrayList<>(allFields.size()); + for (Field f : allFields) { + if (seen.containsKey(f.getName())) { + if ( ! f.equals(seen.get(f.getName()))) { + throw new IllegalArgumentException("Field '" + f.getName() + "' has conflicting definitions in multiple inheritance." + + "First defined as '" + seen.get(f.getName()) + "', then as '" + f + "'."); + } + } else { + unique.add(f); + seen.put(f.getName(), f); + } + } + return unique; + } + return allFields; + } + /** * The Java class the class of the given type should inherit from. If the input type inherits from _one_ * other type, use that, otherwise Document. */ - private static String javaSuperType(NewDocumentType docType) { + private static Pair<String,Boolean> javaSuperType(NewDocumentType docType) { String ret = "com.yahoo.document.Document"; Collection<NewDocumentType> specInheriteds = specificInheriteds(docType); - if (!specInheriteds.isEmpty() && singleInheritance(specInheriteds)) ret = className(specInheriteds.iterator().next().getName()); - return ret; + boolean singleExtends = singleInheritance(specInheriteds); + if (!specInheriteds.isEmpty() && singleExtends) ret = className(specInheriteds.iterator().next().getName()); + return new Pair<>(ret, !singleExtends); } private static boolean singleInheritance(Collection<NewDocumentType> specInheriteds) { diff --git a/vespa-documentgen-plugin/src/test/java/com/yahoo/vespa/DocumentGenTest.java b/vespa-documentgen-plugin/src/test/java/com/yahoo/vespa/DocumentGenTest.java index a9a5893cf96..b21f38c586a 100644 --- a/vespa-documentgen-plugin/src/test/java/com/yahoo/vespa/DocumentGenTest.java +++ b/vespa-documentgen-plugin/src/test/java/com/yahoo/vespa/DocumentGenTest.java @@ -33,6 +33,7 @@ public class DocumentGenTest { mojo.execute(new File("etc/complex/"), new File("target/generated-test-sources/vespa-documentgen-plugin/"), "com.yahoo.vespa.document"); Map<String, Search> searches = mojo.getSearches(); assertEquals(searches.get("video").getDocument("video").getField("weight").getDataType(), DataType.FLOAT); + assertEquals(searches.get("book").getDocument("book").getField("sw1").getDataType(), DataType.FLOAT); assertTrue(searches.get("book").getDocument("book").getField("mystruct").getDataType() instanceof StructDataType); assertTrue(searches.get("book").getDocument("book").getField("mywsfloat").getDataType() instanceof WeightedSetDataType); assertTrue(((WeightedSetDataType)(searches.get("book").getDocument("book").getField("mywsfloat").getDataType())).getNestedType() == DataType.FLOAT); diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/StatusResponse.java b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/StatusResponse.java index c6b5a6cb4fe..38f9e229726 100755 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/StatusResponse.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/StatusResponse.java @@ -11,6 +11,10 @@ import java.io.IOException; import java.io.OutputStream; import java.io.OutputStreamWriter; +/** + * @deprecated Legacy API. Will be removed in Vespa 7 + */ +@Deprecated public class StatusResponse extends HttpResponse { MetricManager manager; diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerCompatibility.java b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerCompatibility.java index b1a7b6dbdeb..dc23589f8ed 100755 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerCompatibility.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerCompatibility.java @@ -9,6 +9,10 @@ import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.container.jdisc.ThreadedHttpRequestHandler; +/** + * @deprecated Legacy API. Will be removed in Vespa 7 + */ +@Deprecated public class VespaFeedHandlerCompatibility extends ThreadedHttpRequestHandler { private final VespaFeedHandlerGet getHandler; diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerGet.java b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerGet.java index 70631e0e66c..ed4750148bd 100755 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerGet.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerGet.java @@ -11,6 +11,10 @@ import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.container.jdisc.ThreadedHttpRequestHandler; import com.yahoo.search.handler.SearchHandler; +/** + * @deprecated Legacy API. Will be removed in Vespa 7 + */ +@Deprecated public class VespaFeedHandlerGet extends ThreadedHttpRequestHandler { private final SearchHandler searchHandler; diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemove.java b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemove.java index 36ab8090e95..4673efb4605 100755 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemove.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemove.java @@ -21,6 +21,10 @@ import java.io.BufferedReader; import java.io.InputStreamReader; import java.util.concurrent.Executor; +/** + * @deprecated Legacy API. Will be removed in Vespa 7 + */ +@Deprecated public class VespaFeedHandlerRemove extends VespaFeedHandlerBase { @Inject diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemoveLocation.java b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemoveLocation.java index 04ca6798b4c..ecb911953f6 100644 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemoveLocation.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemoveLocation.java @@ -20,6 +20,10 @@ import com.yahoo.vespaclient.config.FeederConfig; import java.util.concurrent.Executor; +/** + * @deprecated Legacy API. Will be removed in Vespa 7 + */ +@Deprecated public class VespaFeedHandlerRemoveLocation extends VespaFeedHandlerBase { @Inject diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerStatus.java b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerStatus.java index 94ad18fbb51..8c07ea30312 100755 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerStatus.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerStatus.java @@ -15,6 +15,10 @@ import com.yahoo.metrics.MetricManager; import com.yahoo.metrics.MetricSet; import com.yahoo.vespaclient.config.FeederConfig; +/** + * @deprecated Legacy API. Will be removed in Vespa 7 + */ +@Deprecated public class VespaFeedHandlerStatus extends ThreadedHttpRequestHandler { private MetricManager manager; diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerVisit.java b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerVisit.java index c9af0933799..5b5224775cb 100644 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerVisit.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerVisit.java @@ -13,7 +13,10 @@ import com.yahoo.search.handler.SearchHandler; /** * @author thomasg + * + * @deprecated Legacy API. Will be removed in Vespa 7 */ +@Deprecated public class VespaFeedHandlerVisit extends ThreadedHttpRequestHandler { private final SearchHandler searchHandler; diff --git a/vespaclient-container-plugin/src/test/java/com/yahoo/feedhandler/VespaFeedHandlerTestCase.java b/vespaclient-container-plugin/src/test/java/com/yahoo/feedhandler/VespaFeedHandlerTestCase.java index d1ed02209b2..fcc4e18d66e 100755 --- a/vespaclient-container-plugin/src/test/java/com/yahoo/feedhandler/VespaFeedHandlerTestCase.java +++ b/vespaclient-container-plugin/src/test/java/com/yahoo/feedhandler/VespaFeedHandlerTestCase.java @@ -39,6 +39,7 @@ import java.util.zip.GZIPOutputStream; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +@SuppressWarnings("deprecation") // VespaFeedHandler classes are going away on Vespa 7 public class VespaFeedHandlerTestCase { private VespaFeedHandler feedHandler; diff --git a/vespaclient/src/perl/lib/Yahoo/Vespa/VespaModel.pm b/vespaclient/src/perl/lib/Yahoo/Vespa/VespaModel.pm index fd324540bba..18af0bbdecd 100644 --- a/vespaclient/src/perl/lib/Yahoo/Vespa/VespaModel.pm +++ b/vespaclient/src/perl/lib/Yahoo/Vespa/VespaModel.pm @@ -162,7 +162,7 @@ sub setModelRetrievalFunction { # (Function) } sub retrieveModelConfigDefault { # () my $VESPA_HOME= $ENV{'VESPA_HOME'}; - my $cmd = ${VESPA_HOME} . '/bin/vespa-get-config -n cloud.config.model -i admin/model'; + my $cmd = ${VESPA_HOME} . '/bin/vespa-get-config -l -n cloud.config.model -i admin/model'; if (defined $CONFIG_REQUEST_TIMEOUT) { $cmd .= " -w $CONFIG_REQUEST_TIMEOUT"; diff --git a/vespajlib/pom.xml b/vespajlib/pom.xml index 880d039bc54..5b9c143a447 100644 --- a/vespajlib/pom.xml +++ b/vespajlib/pom.xml @@ -17,29 +17,27 @@ </description> <dependencies> - <dependency> - <groupId>com.google.guava</groupId> - <artifactId>guava</artifactId> - <scope>provided</scope> - </dependency> + + <!-- compile scope --> <dependency> <groupId>net.jpountz.lz4</groupId> <artifactId>lz4</artifactId> </dependency> <dependency> - <groupId>org.hamcrest</groupId> - <artifactId>hamcrest-library</artifactId> - <scope>test</scope> + <groupId>commons-lang</groupId> + <artifactId>commons-lang</artifactId> </dependency> <dependency> - <groupId>org.mockito</groupId> - <artifactId>mockito-core</artifactId> - <scope>test</scope> + <groupId>org.apache.commons</groupId> + <artifactId>commons-exec</artifactId> </dependency> + + + <!-- provided scope --> <dependency> - <groupId>junit</groupId> - <artifactId>junit</artifactId> - <scope>test</scope> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + <scope>provided</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> @@ -54,22 +52,41 @@ <scope>provided</scope> </dependency> <dependency> - <groupId>commons-lang</groupId> - <artifactId>commons-lang</artifactId> + <groupId>org.bouncycastle</groupId> + <artifactId>bcprov-jdk15on</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.bouncycastle</groupId> + <artifactId>bcpkix-jdk15on</artifactId> + <scope>provided</scope> </dependency> <dependency> <groupId>com.fasterxml.jackson.core</groupId> <artifactId>jackson-core</artifactId> - <scope>test</scope> + <scope>provided</scope> </dependency> <dependency> <groupId>com.fasterxml.jackson.core</groupId> <artifactId>jackson-databind</artifactId> + <scope>provided</scope> + </dependency> + + <!-- test scope --> + <dependency> + <groupId>org.hamcrest</groupId> + <artifactId>hamcrest-library</artifactId> <scope>test</scope> </dependency> <dependency> - <groupId>org.apache.commons</groupId> - <artifactId>commons-exec</artifactId> + <groupId>org.mockito</groupId> + <artifactId>mockito-core</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>junit</groupId> + <artifactId>junit</artifactId> + <scope>test</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> @@ -77,6 +94,7 @@ <version>${project.version}</version> <scope>test</scope> </dependency> + </dependencies> <build> <plugins> diff --git a/vespajlib/src/main/java/com/yahoo/security/BasicConstraintsExtension.java b/vespajlib/src/main/java/com/yahoo/security/BasicConstraintsExtension.java new file mode 100644 index 00000000000..d3c08ba27d0 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/BasicConstraintsExtension.java @@ -0,0 +1,14 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +/** + * @author bjorncs + */ +class BasicConstraintsExtension { + final boolean isCritical, isCertAuthorityCertificate; + + BasicConstraintsExtension(boolean isCritical, boolean isCertAuthorityCertificate) { + this.isCritical = isCritical; + this.isCertAuthorityCertificate = isCertAuthorityCertificate; + } +} diff --git a/vespajlib/src/main/java/com/yahoo/security/BouncyCastleProviderHolder.java b/vespajlib/src/main/java/com/yahoo/security/BouncyCastleProviderHolder.java new file mode 100644 index 00000000000..48a23a1fe7e --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/BouncyCastleProviderHolder.java @@ -0,0 +1,14 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.bouncycastle.jce.provider.BouncyCastleProvider; + +/** + * @author bjorncs + */ +class BouncyCastleProviderHolder { + + private static final BouncyCastleProvider bcProvider = new BouncyCastleProvider(); + + static BouncyCastleProvider getInstance() { return bcProvider; } +} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Extension.java b/vespajlib/src/main/java/com/yahoo/security/Extension.java index 18403669c4d..46b781c9c86 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Extension.java +++ b/vespajlib/src/main/java/com/yahoo/security/Extension.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.athenz.tls; +package com.yahoo.security; import org.bouncycastle.asn1.ASN1ObjectIdentifier; @@ -7,7 +7,7 @@ import org.bouncycastle.asn1.ASN1ObjectIdentifier; * @author bjorncs */ public enum Extension { - BASIC_CONSTRAINS(org.bouncycastle.asn1.x509.Extension.basicConstraints), + BASIC_CONSTRAINTS(org.bouncycastle.asn1.x509.Extension.basicConstraints), SUBJECT_ALTERNATIVE_NAMES(org.bouncycastle.asn1.x509.Extension.subjectAlternativeName); final ASN1ObjectIdentifier extensionOId; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyAlgorithm.java b/vespajlib/src/main/java/com/yahoo/security/KeyAlgorithm.java index 4c4198adaac..3218f81f0d6 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyAlgorithm.java +++ b/vespajlib/src/main/java/com/yahoo/security/KeyAlgorithm.java @@ -1,13 +1,14 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.athenz.tls; +package com.yahoo.security; /** * @author bjorncs */ public enum KeyAlgorithm { - RSA("RSA"); + RSA("RSA"), + EC("EC"); - private final String algorithmName; + final String algorithmName; KeyAlgorithm(String algorithmName) { this.algorithmName = algorithmName; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreBuilder.java b/vespajlib/src/main/java/com/yahoo/security/KeyStoreBuilder.java index a9279f45129..2160fbf6455 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreBuilder.java +++ b/vespajlib/src/main/java/com/yahoo/security/KeyStoreBuilder.java @@ -1,12 +1,12 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.athenz.tls; +package com.yahoo.security; import java.io.BufferedInputStream; -import java.io.File; -import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.security.GeneralSecurityException; import java.security.KeyStore; import java.security.PrivateKey; @@ -26,7 +26,7 @@ public class KeyStoreBuilder { private final List<CertificateEntry> certificateEntries = new ArrayList<>(); private final KeyStoreType keyStoreType; - private File inputFile; + private Path inputFile; private char[] inputFilePassword; private KeyStoreBuilder(KeyStoreType keyStoreType) { @@ -37,13 +37,13 @@ public class KeyStoreBuilder { return new KeyStoreBuilder(type); } - public KeyStoreBuilder fromFile(File file, char[] password) { + public KeyStoreBuilder fromFile(Path file, char[] password) { this.inputFile = file; this.inputFilePassword = password; return this; } - public KeyStoreBuilder fromFile(File file) { + public KeyStoreBuilder fromFile(Path file) { return fromFile(file, null); } @@ -73,7 +73,7 @@ public class KeyStoreBuilder { try { KeyStore keystore = this.keyStoreType.createKeystore(); if (this.inputFile != null) { - try (InputStream in = new BufferedInputStream(new FileInputStream(this.inputFile))) { + try (InputStream in = new BufferedInputStream(Files.newInputStream(this.inputFile))) { keystore.load(in, this.inputFilePassword); } } else { diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreType.java b/vespajlib/src/main/java/com/yahoo/security/KeyStoreType.java index 6c08a60ff5b..7fb8df35286 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreType.java +++ b/vespajlib/src/main/java/com/yahoo/security/KeyStoreType.java @@ -1,7 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.athenz.tls; - -import org.bouncycastle.jce.provider.BouncyCastleProvider; +package com.yahoo.security; import java.security.GeneralSecurityException; import java.security.KeyStore; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreUtils.java b/vespajlib/src/main/java/com/yahoo/security/KeyStoreUtils.java index 12aaa40cce4..f0c4d99bf69 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreUtils.java +++ b/vespajlib/src/main/java/com/yahoo/security/KeyStoreUtils.java @@ -1,12 +1,12 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.athenz.tls; +package com.yahoo.security; import java.io.BufferedOutputStream; -import java.io.File; -import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.security.GeneralSecurityException; import java.security.KeyStore; @@ -16,8 +16,8 @@ import java.security.KeyStore; public class KeyStoreUtils { private KeyStoreUtils() {} - public static void writeKeyStoreToFile(KeyStore keyStore, File file, char[] password) { - try (OutputStream out = new BufferedOutputStream(new FileOutputStream(file))) { + public static void writeKeyStoreToFile(KeyStore keyStore, Path file, char[] password) { + try (OutputStream out = new BufferedOutputStream(Files.newOutputStream(file))) { keyStore.store(out, password); } catch (IOException e) { throw new UncheckedIOException(e); @@ -27,7 +27,7 @@ public class KeyStoreUtils { } - public static void writeKeyStoreToFile(KeyStore keyStore, File file) { + public static void writeKeyStoreToFile(KeyStore keyStore, Path file) { writeKeyStoreToFile(keyStore, file, new char[0]); } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyUtils.java b/vespajlib/src/main/java/com/yahoo/security/KeyUtils.java index c2be1a40893..11fb0f432e4 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyUtils.java +++ b/vespajlib/src/main/java/com/yahoo/security/KeyUtils.java @@ -1,10 +1,14 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.athenz.tls; +package com.yahoo.security; -import com.yahoo.athenz.auth.util.Crypto; import org.bouncycastle.asn1.ASN1Encodable; import org.bouncycastle.asn1.ASN1Primitive; import org.bouncycastle.asn1.pkcs.PrivateKeyInfo; +import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPrivateKey; +import org.bouncycastle.jce.spec.ECParameterSpec; +import org.bouncycastle.jce.spec.ECPublicKeySpec; +import org.bouncycastle.math.ec.ECPoint; +import org.bouncycastle.math.ec.FixedPointCombMultiplier; import org.bouncycastle.openssl.PEMKeyPair; import org.bouncycastle.openssl.PEMParser; import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter; @@ -21,7 +25,12 @@ import java.security.KeyPair; import java.security.KeyPairGenerator; import java.security.PrivateKey; import java.security.PublicKey; +import java.security.interfaces.RSAPrivateCrtKey; import java.security.spec.PKCS8EncodedKeySpec; +import java.security.spec.RSAPublicKeySpec; + +import static com.yahoo.security.KeyAlgorithm.EC; +import static com.yahoo.security.KeyAlgorithm.RSA; /** * @author bjorncs @@ -31,7 +40,7 @@ public class KeyUtils { public static KeyPair generateKeypair(KeyAlgorithm algorithm, int keySize) { try { - KeyPairGenerator keyGen = KeyPairGenerator.getInstance(algorithm.getAlgorithmName()); + KeyPairGenerator keyGen = KeyPairGenerator.getInstance(algorithm.getAlgorithmName(), BouncyCastleProviderHolder.getInstance()); if (keySize != -1) { keyGen.initialize(keySize); } @@ -46,7 +55,26 @@ public class KeyUtils { } public static PublicKey extractPublicKey(PrivateKey privateKey) { - return Crypto.extractPublicKey(privateKey); + String algorithm = privateKey.getAlgorithm(); + try { + if (algorithm.equals(RSA.getAlgorithmName())) { + KeyFactory keyFactory = KeyFactory.getInstance(RSA.getAlgorithmName(), BouncyCastleProviderHolder.getInstance()); + RSAPrivateCrtKey rsaPrivateCrtKey = (RSAPrivateCrtKey) privateKey; + RSAPublicKeySpec keySpec = new RSAPublicKeySpec(rsaPrivateCrtKey.getModulus(), rsaPrivateCrtKey.getPublicExponent()); + return keyFactory.generatePublic(keySpec); + } else if (algorithm.equals(EC.getAlgorithmName())) { + KeyFactory keyFactory = KeyFactory.getInstance(EC.getAlgorithmName(), BouncyCastleProviderHolder.getInstance()); + BCECPrivateKey ecPrivateKey = (BCECPrivateKey) privateKey; + ECParameterSpec ecParameterSpec = ecPrivateKey.getParameters(); + ECPoint ecPoint = new FixedPointCombMultiplier().multiply(ecParameterSpec.getG(), ecPrivateKey.getD()); + ECPublicKeySpec keySpec = new ECPublicKeySpec(ecPoint, ecParameterSpec); + return keyFactory.generatePublic(keySpec); + } else { + throw new IllegalArgumentException("Unexpected key algorithm: " + algorithm); + } + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } } public static PrivateKey fromPemEncodedPrivateKey(String pem) { @@ -55,11 +83,11 @@ public class KeyUtils { if (pemObject instanceof PrivateKeyInfo) { PrivateKeyInfo keyInfo = (PrivateKeyInfo) pemObject; PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(keyInfo.getEncoded()); - return KeyFactory.getInstance(KeyAlgorithm.RSA.getAlgorithmName()).generatePrivate(keySpec); + return KeyFactory.getInstance(RSA.getAlgorithmName()).generatePrivate(keySpec); } else if (pemObject instanceof PEMKeyPair) { PEMKeyPair pemKeypair = (PEMKeyPair) pemObject; PrivateKeyInfo keyInfo = pemKeypair.getPrivateKeyInfo(); - JcaPEMKeyConverter pemConverter = new JcaPEMKeyConverter(); + JcaPEMKeyConverter pemConverter = new JcaPEMKeyConverter().setProvider(BouncyCastleProviderHolder.getInstance()); return pemConverter.getPrivateKey(keyInfo); } throw new IllegalArgumentException("Unexpected type of PEM type: " + pemObject); @@ -72,8 +100,17 @@ public class KeyUtils { public static String toPem(PrivateKey privateKey) { try (StringWriter stringWriter = new StringWriter(); JcaPEMWriter pemWriter = new JcaPEMWriter(stringWriter)) { + String algorithm = privateKey.getAlgorithm(); // Note: Encoding using PKCS#1 as this is to be read by tools only supporting PKCS#1 - pemWriter.writeObject(new PemObject("RSA PRIVATE KEY", getPkcs1Bytes(privateKey))); + String type; + if (algorithm.equals(RSA.getAlgorithmName())) { + type = "RSA PRIVATE KEY"; + } else if (algorithm.equals(EC.getAlgorithmName())) { + type = "EC PRIVATE KEY"; + } else { + throw new IllegalArgumentException("Unexpected key algorithm: " + algorithm); + } + pemWriter.writeObject(new PemObject(type, getPkcs1Bytes(privateKey))); pemWriter.flush(); return stringWriter.toString(); } catch (IOException e) { diff --git a/vespajlib/src/main/java/com/yahoo/security/Pkcs10Csr.java b/vespajlib/src/main/java/com/yahoo/security/Pkcs10Csr.java new file mode 100644 index 00000000000..e08ee117fcd --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/Pkcs10Csr.java @@ -0,0 +1,71 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.bouncycastle.asn1.ASN1ObjectIdentifier; +import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers; +import org.bouncycastle.asn1.x509.BasicConstraints; +import org.bouncycastle.asn1.x509.Extension; +import org.bouncycastle.asn1.x509.Extensions; +import org.bouncycastle.asn1.x509.GeneralNames; +import org.bouncycastle.pkcs.PKCS10CertificationRequest; + +import javax.security.auth.x500.X500Principal; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import static java.util.Collections.emptyList; +import static java.util.stream.Collectors.toList; + +/** + * @author bjorncs + */ +public class Pkcs10Csr { + + private final PKCS10CertificationRequest csr; + + Pkcs10Csr(PKCS10CertificationRequest csr) { + this.csr = csr; + } + + PKCS10CertificationRequest getBcCsr() { + return csr; + } + + public X500Principal getSubject() { + return new X500Principal(csr.getSubject().toString()); + } + + public List<SubjectAlternativeName> getSubjectAlternativeNames() { + return getExtensions() + .map(extensions -> GeneralNames.fromExtensions(extensions, Extension.subjectAlternativeName)) + .map(SubjectAlternativeName::fromGeneralNames) + .orElse(emptyList()); + } + + /** + * @return If basic constraints extension is present: returns true if CA cert, false otherwise. Returns empty if the extension is not present. + */ + public Optional<Boolean> getBasicConstraints() { + return getExtensions() + .map(BasicConstraints::fromExtensions) + .map(BasicConstraints::isCA); + } + + public List<String> getExtensionOIds() { + return getExtensions() + .map(extensions -> Arrays.stream(extensions.getExtensionOIDs()) + .map(ASN1ObjectIdentifier::getId) + .collect(toList())) + .orElse(emptyList()); + + } + + private Optional<Extensions> getExtensions() { + return Optional.of(csr.getAttributes(PKCSObjectIdentifiers.pkcs_9_at_extensionRequest)) + .filter(attributes -> attributes.length > 0) + .map(attributes -> attributes[0]) + .map(attribute -> Extensions.getInstance(attribute.getAttrValues().getObjectAt(0))); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/security/Pkcs10CsrBuilder.java b/vespajlib/src/main/java/com/yahoo/security/Pkcs10CsrBuilder.java new file mode 100644 index 00000000000..b46293b2e2f --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/Pkcs10CsrBuilder.java @@ -0,0 +1,105 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers; +import org.bouncycastle.asn1.x509.BasicConstraints; +import org.bouncycastle.asn1.x509.Extension; +import org.bouncycastle.asn1.x509.ExtensionsGenerator; +import org.bouncycastle.asn1.x509.GeneralName; +import org.bouncycastle.asn1.x509.GeneralNames; +import org.bouncycastle.operator.ContentSigner; +import org.bouncycastle.operator.OperatorCreationException; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; +import org.bouncycastle.pkcs.PKCS10CertificationRequestBuilder; +import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequestBuilder; + +import javax.security.auth.x500.X500Principal; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.security.KeyPair; +import java.util.ArrayList; +import java.util.List; + +import static com.yahoo.security.SubjectAlternativeName.Type.DNS_NAME; + +/** + * @author bjorncs + */ +public class Pkcs10CsrBuilder { + + private final X500Principal subject; + private final KeyPair keyPair; + private final List<SubjectAlternativeName> subjectAlternativeNames = new ArrayList<>(); + private final SignatureAlgorithm signatureAlgorithm; + private BasicConstraintsExtension basicConstraintsExtension; + + private Pkcs10CsrBuilder(X500Principal subject, + KeyPair keyPair, + SignatureAlgorithm signatureAlgorithm) { + this.subject = subject; + this.keyPair = keyPair; + this.signatureAlgorithm = signatureAlgorithm; + } + + public static Pkcs10CsrBuilder fromKeypair(X500Principal subject, + KeyPair keyPair, + SignatureAlgorithm signatureAlgorithm) { + return new Pkcs10CsrBuilder(subject, keyPair, signatureAlgorithm); + } + + public Pkcs10CsrBuilder addSubjectAlternativeName(String dns) { + this.subjectAlternativeNames.add(new SubjectAlternativeName(DNS_NAME, dns)); + return this; + } + + public Pkcs10CsrBuilder addSubjectAlternativeName(SubjectAlternativeName san) { + this.subjectAlternativeNames.add(san); + return this; + } + + public Pkcs10CsrBuilder addSubjectAlternativeName(SubjectAlternativeName.Type type, String value) { + this.subjectAlternativeNames.add(new SubjectAlternativeName(type, value)); + return this; + } + + public Pkcs10CsrBuilder setBasicConstraints(boolean isCritical, boolean isCertAuthorityCertificate) { + this.basicConstraintsExtension = new BasicConstraintsExtension(isCritical, isCertAuthorityCertificate); + return this; + } + + public Pkcs10CsrBuilder setIsCertAuthority(boolean isCertAuthority) { + return setBasicConstraints(true, isCertAuthority); + } + + public Pkcs10Csr build() { + try { + PKCS10CertificationRequestBuilder requestBuilder = + new JcaPKCS10CertificationRequestBuilder(subject, keyPair.getPublic()); + ExtensionsGenerator extGen = new ExtensionsGenerator(); + if (basicConstraintsExtension != null) { + extGen.addExtension( + Extension.basicConstraints, + basicConstraintsExtension.isCritical, + new BasicConstraints(basicConstraintsExtension.isCertAuthorityCertificate)); + } + if (!subjectAlternativeNames.isEmpty()) { + GeneralNames generalNames = new GeneralNames( + subjectAlternativeNames.stream() + .map(SubjectAlternativeName::toGeneralName) + .toArray(GeneralName[]::new)); + extGen.addExtension(Extension.subjectAlternativeName, false, generalNames); + } + requestBuilder.addAttribute(PKCSObjectIdentifiers.pkcs_9_at_extensionRequest, extGen.generate()); + ContentSigner contentSigner = new JcaContentSignerBuilder(signatureAlgorithm.getAlgorithmName()) + .setProvider(BouncyCastleProviderHolder.getInstance()) + .build(keyPair.getPrivate()); + return new Pkcs10Csr(requestBuilder.build(contentSigner)); + } catch (OperatorCreationException e) { + throw new RuntimeException(e); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/security/Pkcs10CsrUtils.java b/vespajlib/src/main/java/com/yahoo/security/Pkcs10CsrUtils.java new file mode 100644 index 00000000000..6f12450528d --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/Pkcs10CsrUtils.java @@ -0,0 +1,38 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.bouncycastle.openssl.PEMParser; +import org.bouncycastle.openssl.jcajce.JcaPEMWriter; +import org.bouncycastle.pkcs.PKCS10CertificationRequest; +import org.bouncycastle.util.io.pem.PemObject; + +import java.io.IOException; +import java.io.StringReader; +import java.io.StringWriter; +import java.io.UncheckedIOException; + +/** + * @author bjorncs + */ +public class Pkcs10CsrUtils { + + private Pkcs10CsrUtils() {} + + public static Pkcs10Csr fromPem(String pem) { + try (PEMParser pemParser = new PEMParser(new StringReader(pem))) { + return new Pkcs10Csr((PKCS10CertificationRequest) pemParser.readObject()); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public static String toPem(Pkcs10Csr csr) { + try (StringWriter stringWriter = new StringWriter(); JcaPEMWriter pemWriter = new JcaPEMWriter(stringWriter)) { + pemWriter.writeObject(new PemObject("CERTIFICATE REQUEST", csr.getBcCsr().getEncoded())); + pemWriter.flush(); + return stringWriter.toString(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } +} diff --git a/vespajlib/src/main/java/com/yahoo/security/SignatureAlgorithm.java b/vespajlib/src/main/java/com/yahoo/security/SignatureAlgorithm.java new file mode 100644 index 00000000000..fbff18f5c12 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/SignatureAlgorithm.java @@ -0,0 +1,22 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +/** + * @author bjorncs + */ +public enum SignatureAlgorithm { + SHA256_WITH_RSA("SHA256withRSA"), + SHA512_WITH_RSA("SHA512withRSA"), + SHA256_WITH_ECDSA("SHA256withECDSA"), + SHA512_WITH_ECDSA("SHA512withECDSA"); + + private final String algorithmName; + + SignatureAlgorithm(String algorithmName) { + this.algorithmName = algorithmName; + } + + public String getAlgorithmName() { + return algorithmName; + } +} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SslContextBuilder.java b/vespajlib/src/main/java/com/yahoo/security/SslContextBuilder.java index ba5785043da..75ab2417edf 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SslContextBuilder.java +++ b/vespajlib/src/main/java/com/yahoo/security/SslContextBuilder.java @@ -1,12 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.athenz.tls; +package com.yahoo.security; import javax.net.ssl.KeyManager; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; -import java.io.File; import java.io.IOException; import java.io.UncheckedIOException; import java.nio.file.Files; @@ -15,6 +14,10 @@ import java.security.GeneralSecurityException; import java.security.KeyStore; import java.security.PrivateKey; import java.security.cert.X509Certificate; +import java.util.Collections; +import java.util.List; + +import static java.util.Collections.singletonList; /** * @author bjorncs @@ -27,7 +30,7 @@ public class SslContextBuilder { public SslContextBuilder() {} - public SslContextBuilder withTrustStore(File file, KeyStoreType trustStoreType) { + public SslContextBuilder withTrustStore(Path file, KeyStoreType trustStoreType) { this.trustStoreSupplier = () -> KeyStoreBuilder.withType(trustStoreType).fromFile(file).build(); return this; } @@ -37,6 +40,24 @@ public class SslContextBuilder { return this; } + public SslContextBuilder withTrustStore(X509Certificate caCertificate) { + return withTrustStore(singletonList(caCertificate)); + } + + public SslContextBuilder withTrustStore(List<X509Certificate> caCertificates) { + this.trustStoreSupplier = () -> createTrustStore(caCertificates); + return this; + } + + public SslContextBuilder withTrustStore(Path pemEncodedCaCertificates) { + this.trustStoreSupplier = () -> { + List<X509Certificate> caCertificates = + X509CertificateUtils.certificateListFromPem(new String(Files.readAllBytes(pemEncodedCaCertificates))); + return createTrustStore(caCertificates); + }; + return this; + } + public SslContextBuilder withKeyStore(PrivateKey privateKey, X509Certificate certificate) { char[] pwd = new char[0]; this.keyStoreSupplier = () -> KeyStoreBuilder.withType(KeyStoreType.JKS).withKeyEntry("default", privateKey, certificate).build(); @@ -50,23 +71,19 @@ public class SslContextBuilder { return this; } - public SslContextBuilder withKeyStore(File file, char[] password, KeyStoreType keyStoreType) { + public SslContextBuilder withKeyStore(Path file, char[] password, KeyStoreType keyStoreType) { this.keyStoreSupplier = () -> KeyStoreBuilder.withType(keyStoreType).fromFile(file, password).build(); this.keyStorePassword = password; return this; } - public SslContextBuilder withKeyStore(File privateKeyPemFile, File certificatePemFile) { - return withKeyStore(privateKeyPemFile.toPath(), certificatePemFile.toPath()); - } - - public SslContextBuilder withKeyStore(Path privateKeyPemFile, Path certificatePemFile) { + public SslContextBuilder withKeyStore(Path privateKeyPemFile, Path certificatesPemFile) { this.keyStoreSupplier = () -> { PrivateKey privateKey = KeyUtils.fromPemEncodedPrivateKey(new String(Files.readAllBytes(privateKeyPemFile))); - X509Certificate certificate = X509CertificateUtils.fromPem(new String(Files.readAllBytes(certificatePemFile))); + List<X509Certificate> certificates = X509CertificateUtils.certificateListFromPem(new String(Files.readAllBytes(certificatesPemFile))); return KeyStoreBuilder.withType(KeyStoreType.JKS) - .withKeyEntry("default", privateKey, certificate) + .withKeyEntry("default", privateKey, certificates) .build(); }; this.keyStorePassword = new char[0]; @@ -105,6 +122,14 @@ public class SslContextBuilder { return keyManagerFactory.getKeyManagers(); } + private static KeyStore createTrustStore(List<X509Certificate> caCertificates) { + KeyStoreBuilder trustStoreBuilder = KeyStoreBuilder.withType(KeyStoreType.JKS); + for (int i = 0; i < caCertificates.size(); i++) { + trustStoreBuilder.withCertificateEntry("cert-" + i, caCertificates.get(i)); + } + return trustStoreBuilder.build(); + } + private interface KeyStoreSupplier { KeyStore get() throws IOException, GeneralSecurityException; } diff --git a/vespajlib/src/main/java/com/yahoo/security/SubjectAlternativeName.java b/vespajlib/src/main/java/com/yahoo/security/SubjectAlternativeName.java new file mode 100644 index 00000000000..29395c75e70 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/SubjectAlternativeName.java @@ -0,0 +1,114 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.bouncycastle.asn1.ASN1Encodable; +import org.bouncycastle.asn1.DERIA5String; +import org.bouncycastle.asn1.x500.X500Name; +import org.bouncycastle.asn1.x509.GeneralName; +import org.bouncycastle.asn1.x509.GeneralNames; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import static java.util.stream.Collectors.toList; + +/** + * @author bjorncs + */ +public class SubjectAlternativeName { + + private final Type type; + private final String value; + + public SubjectAlternativeName(Type type, String value) { + this.type = type; + this.value = value; + } + + SubjectAlternativeName(GeneralName bcGeneralName) { + this.type = Type.fromTag(bcGeneralName.getTagNo()); + this.value = getValue(bcGeneralName); + } + + public Type getType() { + return type; + } + + public String getValue() { + return value; + } + + GeneralName toGeneralName() { + return new GeneralName(type.tag, value); + } + + static List<SubjectAlternativeName> fromGeneralNames(GeneralNames generalNames) { + return Arrays.stream(generalNames.getNames()).map(SubjectAlternativeName::new).collect(toList()); + } + + private String getValue(GeneralName bcGeneralName) { + ASN1Encodable name = bcGeneralName.getName(); + switch (bcGeneralName.getTagNo()) { + case GeneralName.rfc822Name: + case GeneralName.dNSName: + case GeneralName.uniformResourceIdentifier: + return DERIA5String.getInstance(name).getString(); + case GeneralName.directoryName: + return X500Name.getInstance(name).toString(); + default: + return name.toString(); + } + } + + @Override + public String toString() { + return "SubjectAlternativeName{" + + "type=" + type + + ", value='" + value + '\'' + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SubjectAlternativeName that = (SubjectAlternativeName) o; + return type == that.type && + Objects.equals(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(type, value); + } + + public enum Type { + OTHER_NAME(0), + RFC822_NAME(1), + DNS_NAME(2), + X400_ADDRESS(3), + DIRECTORY_NAME(4), + EDI_PARITY_NAME(5), + UNIFORM_RESOURCE_IDENTIFIER(6), + IP_ADDRESS(7), + REGISTERED_ID(8); + + final int tag; + + Type(int tag) { + this.tag = tag; + } + + public static Type fromTag(int tag) { + return Arrays.stream(Type.values()) + .filter(type -> type.tag == tag) + .findAny() + .orElseThrow(() -> new IllegalArgumentException("Invalid tag: " + tag)); + } + + public int getTag() { + return tag; + } + } +} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/X509CertificateBuilder.java b/vespajlib/src/main/java/com/yahoo/security/X509CertificateBuilder.java index c27b704f6a3..54d7d39253e 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/X509CertificateBuilder.java +++ b/vespajlib/src/main/java/com/yahoo/security/X509CertificateBuilder.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.athenz.tls; +package com.yahoo.security; import org.bouncycastle.asn1.x509.BasicConstraints; import org.bouncycastle.asn1.x509.Extension; @@ -21,20 +21,22 @@ import java.security.GeneralSecurityException; import java.security.KeyPair; import java.security.PrivateKey; import java.security.PublicKey; +import java.security.SecureRandom; import java.security.cert.X509Certificate; import java.sql.Date; import java.time.Instant; import java.util.ArrayList; import java.util.List; -import static com.yahoo.vespa.athenz.tls.SubjectAlternativeName.Type.DNS_NAME; +import static com.yahoo.security.SubjectAlternativeName.Type.DNS_NAME; + /** * @author bjorncs */ public class X509CertificateBuilder { - private final long serialNumber; + private final BigInteger serialNumber; private final SignatureAlgorithm signingAlgorithm; private final PrivateKey caPrivateKey; private final Instant notBefore; @@ -52,7 +54,7 @@ public class X509CertificateBuilder { PublicKey certPublicKey, PrivateKey caPrivateKey, SignatureAlgorithm signingAlgorithm, - long serialNumber) { + BigInteger serialNumber) { this.issuer = issuer; this.subject = subject; this.notBefore = notBefore; @@ -69,10 +71,12 @@ public class X509CertificateBuilder { Instant notAfter, PrivateKey caPrivateKey, SignatureAlgorithm signingAlgorithm, - long serialNumber) { + BigInteger serialNumber) { try { PKCS10CertificationRequest bcCsr = csr.getBcCsr(); - PublicKey publicKey = new JcaPKCS10CertificationRequest(bcCsr).getPublicKey(); + PublicKey publicKey = new JcaPKCS10CertificationRequest(bcCsr) + .setProvider(BouncyCastleProviderHolder.getInstance()) + .getPublicKey(); return new X509CertificateBuilder(caIssuer, new X500Principal(bcCsr.getSubject().getEncoded()), notBefore, @@ -93,7 +97,7 @@ public class X509CertificateBuilder { Instant notBefore, Instant notAfter, SignatureAlgorithm signingAlgorithm, - long serialNumber) { + BigInteger serialNumber) { return new X509CertificateBuilder(subject, subject, notBefore, @@ -104,6 +108,13 @@ public class X509CertificateBuilder { serialNumber); } + /** + * @return generates a cryptographically secure positive serial number up to 128 bits + */ + public static BigInteger generateRandomSerialNumber() { + return new BigInteger(128, new SecureRandom()); + } + public X509CertificateBuilder addSubjectAlternativeName(String dnsName) { this.subjectAlternativeNames.add(new SubjectAlternativeName(DNS_NAME, dnsName)); return this; @@ -119,10 +130,14 @@ public class X509CertificateBuilder { return this; } + public X509CertificateBuilder setIsCertAuthority(boolean isCertAuthority) { + return setBasicConstraints(true, isCertAuthority); + } + public X509Certificate build() { try { JcaX509v3CertificateBuilder jcaCertBuilder = new JcaX509v3CertificateBuilder( - issuer, BigInteger.valueOf(serialNumber), Date.from(notBefore), Date.from(notAfter), subject, certPublicKey); + issuer, serialNumber, Date.from(notBefore), Date.from(notAfter), subject, certPublicKey); if (basicConstraintsExtension != null) { jcaCertBuilder.addExtension( Extension.basicConstraints, diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/X509CertificateUtils.java b/vespajlib/src/main/java/com/yahoo/security/X509CertificateUtils.java index d96ed17765c..33bd750bac5 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/X509CertificateUtils.java +++ b/vespajlib/src/main/java/com/yahoo/security/X509CertificateUtils.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.athenz.tls; +package com.yahoo.security; import org.bouncycastle.asn1.ASN1Encodable; import org.bouncycastle.asn1.ASN1OctetString; @@ -25,7 +25,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import static com.yahoo.vespa.athenz.tls.Extension.SUBJECT_ALTERNATIVE_NAMES; +import static com.yahoo.security.Extension.SUBJECT_ALTERNATIVE_NAMES; import static java.util.stream.Collectors.toList; /** diff --git a/vespajlib/src/main/java/com/yahoo/security/package-info.java b/vespajlib/src/main/java/com/yahoo/security/package-info.java new file mode 100644 index 00000000000..10a4c9c0e0d --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/package-info.java @@ -0,0 +1,9 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +/** + * @author bjorncs + */ + +@ExportPackage +package com.yahoo.security; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/vespajlib/src/main/java/com/yahoo/security/tls/TransportSecurityOptions.java b/vespajlib/src/main/java/com/yahoo/security/tls/TransportSecurityOptions.java new file mode 100644 index 00000000000..f0d1edd6889 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/tls/TransportSecurityOptions.java @@ -0,0 +1,90 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Objects; +import java.util.Optional; + +/** + * Generic TLS configuration for Vespa + * + * @author bjorncs + */ +public class TransportSecurityOptions { + + private static final ObjectMapper mapper = new ObjectMapper(); + + private final Path privateKeyFile; + private final Path certificatesFile; + private final Path caCertificatesFile; + + public TransportSecurityOptions(String privateKeyFile, String certificatesFile, String caCertificatesFile) { + this(Paths.get(privateKeyFile), Paths.get(certificatesFile), Paths.get(caCertificatesFile)); + } + + public TransportSecurityOptions(Path privateKeyFile, Path certificatesFile, Path caCertificatesFile) { + this.privateKeyFile = privateKeyFile; + this.certificatesFile = certificatesFile; + this.caCertificatesFile = caCertificatesFile; + } + + public Path getPrivateKeyFile() { + return privateKeyFile; + } + + public Path getCertificatesFile() { + return certificatesFile; + } + + public Path getCaCertificatesFile() { + return caCertificatesFile; + } + + public static TransportSecurityOptions fromJsonFile(Path file) { + try { + JsonNode root = mapper.readTree(file.toFile()); + JsonNode filesNode = getField(root, "files"); + String privateKeyFile = getField(filesNode, "private-key").asText(); + String certificatesFile = getField(filesNode, "certificates").asText(); + String caCertificatesFile = getField(filesNode, "ca-certificates").asText(); + return new TransportSecurityOptions(privateKeyFile, certificatesFile, caCertificatesFile); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private static JsonNode getField(JsonNode root, String fieldName) { + return Optional.ofNullable(root.get(fieldName)) + .orElseThrow(() -> new IllegalArgumentException(String.format("'%s' field missing", fieldName))); + } + + @Override + public String toString() { + return "TransportSecurityOptions{" + + "privateKeyFile=" + privateKeyFile + + ", certificatesFile=" + certificatesFile + + ", caCertificatesFile=" + caCertificatesFile + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TransportSecurityOptions that = (TransportSecurityOptions) o; + return Objects.equals(privateKeyFile, that.privateKeyFile) && + Objects.equals(certificatesFile, that.certificatesFile) && + Objects.equals(caCertificatesFile, that.caCertificatesFile); + } + + @Override + public int hashCode() { + return Objects.hash(privateKeyFile, certificatesFile, caCertificatesFile); + } +}
\ No newline at end of file diff --git a/vespajlib/src/main/java/com/yahoo/security/tls/package-info.java b/vespajlib/src/main/java/com/yahoo/security/tls/package-info.java new file mode 100644 index 00000000000..b5668182f14 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/tls/package-info.java @@ -0,0 +1,8 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +/** + * @author bjorncs + */ +@ExportPackage +package com.yahoo.security.tls; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/KeyStoreBuilderTest.java b/vespajlib/src/test/java/com/yahoo/security/KeyStoreBuilderTest.java index 6060f6f3521..06ea5d963a3 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/KeyStoreBuilderTest.java +++ b/vespajlib/src/test/java/com/yahoo/security/KeyStoreBuilderTest.java @@ -1,15 +1,17 @@ -package com.yahoo.vespa.athenz.tls; +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; -import java.io.File; +import java.nio.file.Path; import java.security.KeyPair; import java.security.cert.X509Certificate; -import static com.yahoo.vespa.athenz.tls.TestUtils.createCertificate; -import static com.yahoo.vespa.athenz.tls.TestUtils.createKeystoreFile; +import static com.yahoo.security.TestUtils.createCertificate; +import static com.yahoo.security.TestUtils.createKeystoreFile; + /** * @author bjorncs @@ -23,7 +25,7 @@ public class KeyStoreBuilderTest { @Test public void can_create_jks_keystore_from_privatekey_and_certificate() throws Exception { - KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA, 4096); + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); X509Certificate certificate = createCertificate(keyPair); KeyStoreBuilder.withType(KeyStoreType.JKS) .withKeyEntry("key", keyPair.getPrivate(), certificate) @@ -32,7 +34,7 @@ public class KeyStoreBuilderTest { @Test public void can_build_jks_keystore_from_file() throws Exception { - File keystoreFile = tempDirectory.newFile(); + Path keystoreFile = tempDirectory.newFile().toPath(); createKeystoreFile(keystoreFile, KeyStoreType.JKS, PASSWORD); KeyStoreBuilder.withType(KeyStoreType.JKS) @@ -42,7 +44,7 @@ public class KeyStoreBuilderTest { @Test public void can_build_pcks12_keystore_from_file() throws Exception { - File keystoreFile = tempDirectory.newFile(); + Path keystoreFile = tempDirectory.newFile().toPath(); createKeystoreFile(keystoreFile, KeyStoreType.PKCS12, PASSWORD); KeyStoreBuilder.withType(KeyStoreType.PKCS12) diff --git a/vespajlib/src/test/java/com/yahoo/security/KeyUtilsTest.java b/vespajlib/src/test/java/com/yahoo/security/KeyUtilsTest.java new file mode 100644 index 00000000000..5e786654d7c --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/KeyUtilsTest.java @@ -0,0 +1,54 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.junit.Test; + +import java.security.KeyPair; +import java.security.PrivateKey; +import java.security.PublicKey; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; + +/** + * @author bjorncs + */ +public class KeyUtilsTest { + + @Test + public void can_extract_public_key_from_rsa_private() { + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); + PublicKey publicKey = KeyUtils.extractPublicKey(keyPair.getPrivate()); + assertNotNull(publicKey); + } + + @Test + public void can_extract_public_key_from_ecdsa_private() { + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC); + PublicKey publicKey = KeyUtils.extractPublicKey(keyPair.getPrivate()); + assertNotNull(publicKey); + } + + @Test + public void can_serialize_and_deserialize_rsa_privatekey_using_pem_format() { + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); + String pem = KeyUtils.toPem(keyPair.getPrivate()); + assertThat(pem, containsString("BEGIN RSA PRIVATE KEY")); + assertThat(pem, containsString("END RSA PRIVATE KEY")); + PrivateKey deserializedKey = KeyUtils.fromPemEncodedPrivateKey(pem); + assertEquals(keyPair.getPrivate(), deserializedKey); + } + + @Test + public void can_serialize_and_deserialize_ec_privatekey_using_pem_format() { + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC); + String pem = KeyUtils.toPem(keyPair.getPrivate()); + assertThat(pem, containsString("BEGIN EC PRIVATE KEY")); + assertThat(pem, containsString("END EC PRIVATE KEY")); + PrivateKey deserializedKey = KeyUtils.fromPemEncodedPrivateKey(pem); + assertEquals(keyPair.getPrivate(), deserializedKey); + } + +}
\ No newline at end of file diff --git a/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrBuilderTest.java b/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrBuilderTest.java new file mode 100644 index 00000000000..d51203a5cb2 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrBuilderTest.java @@ -0,0 +1,27 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.junit.Test; + +import javax.security.auth.x500.X500Principal; +import java.security.KeyPair; + +import static org.junit.Assert.assertEquals; + +/** + * @author bjorncs + */ +public class Pkcs10CsrBuilderTest { + + @Test + public void can_build_csr_with_sans() { + X500Principal subject = new X500Principal("CN=subject"); + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + Pkcs10Csr csr = Pkcs10CsrBuilder.fromKeypair(subject, keypair, SignatureAlgorithm.SHA512_WITH_ECDSA) + .addSubjectAlternativeName("san1.com") + .addSubjectAlternativeName("san2.com") + .build(); + assertEquals(subject, csr.getSubject()); + } + +}
\ No newline at end of file diff --git a/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrTest.java b/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrTest.java new file mode 100644 index 00000000000..cc1f6cc6a14 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrTest.java @@ -0,0 +1,57 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.junit.Test; + +import javax.security.auth.x500.X500Principal; +import java.security.KeyPair; +import java.util.Arrays; +import java.util.List; + +import static com.yahoo.security.SubjectAlternativeName.Type.DNS_NAME; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * @author bjorncs + */ +public class Pkcs10CsrTest { + + @Test + public void can_read_subject_alternative_names() { + X500Principal subject = new X500Principal("CN=subject"); + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + SubjectAlternativeName san1 = new SubjectAlternativeName(DNS_NAME, "san1.com"); + SubjectAlternativeName san2 = new SubjectAlternativeName(DNS_NAME, "san2.com"); + Pkcs10Csr csr = Pkcs10CsrBuilder.fromKeypair(subject, keypair, SignatureAlgorithm.SHA512_WITH_ECDSA) + .addSubjectAlternativeName(san1) + .addSubjectAlternativeName(san2) + .build(); + assertEquals(Arrays.asList(san1, san2), csr.getSubjectAlternativeNames()); + } + + @Test + public void can_read_basic_constraints() { + X500Principal subject = new X500Principal("CN=subject"); + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + Pkcs10Csr csr = Pkcs10CsrBuilder.fromKeypair(subject, keypair, SignatureAlgorithm.SHA512_WITH_ECDSA) + .setBasicConstraints(true, true) + .build(); + assertTrue(csr.getBasicConstraints().isPresent()); + assertTrue(csr.getBasicConstraints().get()); + } + + @Test + public void can_read_extensions() { + X500Principal subject = new X500Principal("CN=subject"); + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + Pkcs10Csr csr = Pkcs10CsrBuilder.fromKeypair(subject, keypair, SignatureAlgorithm.SHA512_WITH_ECDSA) + .addSubjectAlternativeName("san") + .setBasicConstraints(true, true) + .build(); + List<String> expected = Arrays.asList(Extension.BASIC_CONSTRAINTS.getOId(), Extension.SUBJECT_ALTERNATIVE_NAMES.getOId()); + List<String> actual = csr.getExtensionOIds(); + assertEquals(expected, actual); + } + +}
\ No newline at end of file diff --git a/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrUtilsTest.java b/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrUtilsTest.java new file mode 100644 index 00000000000..04d35a537bb --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrUtilsTest.java @@ -0,0 +1,30 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.junit.Test; + +import javax.security.auth.x500.X500Principal; +import java.security.KeyPair; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +/** + * @author bjorncs + */ +public class Pkcs10CsrUtilsTest { + + @Test + public void can_deserialize_serialized_pem_csr() { + X500Principal subject = new X500Principal("CN=subject"); + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + Pkcs10Csr csr = Pkcs10CsrBuilder.fromKeypair(subject, keypair, SignatureAlgorithm.SHA512_WITH_ECDSA).build(); + String pem = Pkcs10CsrUtils.toPem(csr); + Pkcs10Csr deserializedCsr = Pkcs10CsrUtils.fromPem(pem); + assertThat(pem, containsString("BEGIN CERTIFICATE REQUEST")); + assertThat(pem, containsString("END CERTIFICATE REQUEST")); + assertEquals(subject, deserializedCsr.getSubject()); + } + +}
\ No newline at end of file diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/SslContextBuilderTest.java b/vespajlib/src/test/java/com/yahoo/security/SslContextBuilderTest.java index 2f750d915d4..cc269a4ef43 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/SslContextBuilderTest.java +++ b/vespajlib/src/test/java/com/yahoo/security/SslContextBuilderTest.java @@ -1,17 +1,17 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.athenz.tls; +package com.yahoo.security; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; -import java.io.File; +import java.nio.file.Path; import java.security.KeyPair; import java.security.cert.X509Certificate; -import static com.yahoo.vespa.athenz.tls.TestUtils.createCertificate; -import static com.yahoo.vespa.athenz.tls.TestUtils.createKeystore; -import static com.yahoo.vespa.athenz.tls.TestUtils.createKeystoreFile; +import static com.yahoo.security.TestUtils.createCertificate; +import static com.yahoo.security.TestUtils.createKeystore; +import static com.yahoo.security.TestUtils.createKeystoreFile; /** * @author bjorncs @@ -47,7 +47,7 @@ public class SslContextBuilderTest { @Test public void can_build_sslcontext_with_keystore_from_private_key_and_certificate() throws Exception { - KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA, 2048); + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); X509Certificate certificate = createCertificate(keyPair); new SslContextBuilder() .withKeyStore(keyPair.getPrivate(), certificate) @@ -56,7 +56,7 @@ public class SslContextBuilderTest { @Test public void can_build_sslcontext_with_jks_keystore_from_file() throws Exception { - File keystoreFile = tempDirectory.newFile(); + Path keystoreFile = tempDirectory.newFile().toPath(); createKeystoreFile(keystoreFile, KeyStoreType.JKS, PASSWORD); new SslContextBuilder() @@ -66,7 +66,7 @@ public class SslContextBuilderTest { @Test public void can_build_sslcontext_with_pcks12_keystore_from_file() throws Exception { - File keystoreFile = tempDirectory.newFile(); + Path keystoreFile = tempDirectory.newFile().toPath(); createKeystoreFile(keystoreFile, KeyStoreType.PKCS12, PASSWORD); new SslContextBuilder() diff --git a/vespajlib/src/test/java/com/yahoo/security/TestUtils.java b/vespajlib/src/test/java/com/yahoo/security/TestUtils.java new file mode 100644 index 00000000000..fcfcfb2b761 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/TestUtils.java @@ -0,0 +1,42 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import javax.security.auth.x500.X500Principal; +import java.math.BigInteger; +import java.nio.file.Path; +import java.security.KeyPair; +import java.security.KeyStore; +import java.security.cert.X509Certificate; +import java.time.Instant; +import java.time.temporal.ChronoUnit; + +import static com.yahoo.security.KeyStoreUtils.writeKeyStoreToFile; + + +/** + * @author bjorncs + */ +class TestUtils { + + static KeyStore createKeystore(KeyStoreType type, char[] password) { + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + return KeyStoreBuilder.withType(type) + .withKeyEntry("entry-name", keyPair.getPrivate(), password, createCertificate(keyPair)) + .build(); + } + + static X509Certificate createCertificate(KeyPair keyPair) { + return createCertificate(keyPair, new X500Principal("CN=mysubject")); + } + + static X509Certificate createCertificate(KeyPair keyPair, X500Principal subject) { + return X509CertificateBuilder + .fromKeypair( + keyPair, subject, Instant.now(), Instant.now().plus(1, ChronoUnit.DAYS), SignatureAlgorithm.SHA512_WITH_ECDSA, BigInteger.valueOf(1)) + .build(); + } + + static void createKeystoreFile(Path file, KeyStoreType type, char[] password) { + writeKeyStoreToFile(createKeystore(type, password), file, password); + } +} diff --git a/vespajlib/src/test/java/com/yahoo/security/X509CertificateBuilderTest.java b/vespajlib/src/test/java/com/yahoo/security/X509CertificateBuilderTest.java new file mode 100644 index 00000000000..7e6d343b570 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/X509CertificateBuilderTest.java @@ -0,0 +1,83 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import javax.security.auth.x500.X500Principal; +import java.math.BigInteger; +import java.security.KeyPair; +import java.security.cert.X509Certificate; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collection; + +import static org.junit.Assert.assertEquals; + +/** + * @author bjorncs + */ +@RunWith(Parameterized.class) +public class X509CertificateBuilderTest { + + @Parameterized.Parameters(name = "{0}") + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {KeyAlgorithm.RSA, 2048, SignatureAlgorithm.SHA512_WITH_RSA}, + {KeyAlgorithm.EC, 256, SignatureAlgorithm.SHA512_WITH_ECDSA}}); + } + + private final KeyAlgorithm keyAlgorithm; + private final int keySize; + private final SignatureAlgorithm signatureAlgorithm; + + public X509CertificateBuilderTest(KeyAlgorithm keyAlgorithm, + int keySize, + SignatureAlgorithm signatureAlgorithm) { + this.keyAlgorithm = keyAlgorithm; + this.keySize = keySize; + this.signatureAlgorithm = signatureAlgorithm; + } + + @Test + public void can_build_self_signed_certificate() { + KeyPair keyPair = KeyUtils.generateKeypair(keyAlgorithm, keySize); + X500Principal subject = new X500Principal("CN=myservice"); + X509Certificate cert = + X509CertificateBuilder.fromKeypair( + keyPair, + subject, + Instant.now(), + Instant.now().plus(1, ChronoUnit.DAYS), + signatureAlgorithm, + BigInteger.valueOf(1)) + .setBasicConstraints(true, true) + .build(); + assertEquals(subject, cert.getSubjectX500Principal()); + } + + @Test + public void can_build_certificate_from_csr() { + X500Principal subject = new X500Principal("CN=subject"); + X500Principal issuer = new X500Principal("CN=issuer"); + KeyPair csrKeypair = KeyUtils.generateKeypair(keyAlgorithm, keySize); + Pkcs10Csr csr = Pkcs10CsrBuilder.fromKeypair(subject, csrKeypair, signatureAlgorithm).build(); + KeyPair caKeypair = KeyUtils.generateKeypair(keyAlgorithm, keySize); + X509Certificate cert = X509CertificateBuilder + .fromCsr( + csr, + issuer, + Instant.now(), + Instant.now().plus(1, ChronoUnit.DAYS), + caKeypair.getPrivate(), + signatureAlgorithm, + BigInteger.valueOf(1)) + .addSubjectAlternativeName("subject1.alt") + .addSubjectAlternativeName("subject2.alt") + .build(); + assertEquals(subject, cert.getSubjectX500Principal()); + } + +}
\ No newline at end of file diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/X509CertificateUtilsTest.java b/vespajlib/src/test/java/com/yahoo/security/X509CertificateUtilsTest.java index 4039bf36a5f..76a93028efe 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/X509CertificateUtilsTest.java +++ b/vespajlib/src/test/java/com/yahoo/security/X509CertificateUtilsTest.java @@ -1,8 +1,10 @@ -package com.yahoo.vespa.athenz.tls; +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; import org.junit.Test; import javax.security.auth.x500.X500Principal; +import java.math.BigInteger; import java.security.KeyPair; import java.security.cert.X509Certificate; import java.time.Instant; @@ -10,7 +12,7 @@ import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.List; -import static com.yahoo.vespa.athenz.tls.SubjectAlternativeName.Type.DNS_NAME; +import static com.yahoo.security.SubjectAlternativeName.Type.DNS_NAME; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.Matchers.is; @@ -23,7 +25,7 @@ import static org.junit.Assert.assertThat; public class X509CertificateUtilsTest { @Test public void can_deserialize_serialized_pem_certificate() { - KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.RSA, 2048); + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); X500Principal subject = new X500Principal("CN=myservice"); X509Certificate cert = TestUtils.createCertificate(keypair, subject); assertEquals(subject, cert.getSubjectX500Principal()); @@ -36,10 +38,10 @@ public class X509CertificateUtilsTest { @Test public void can_deserialize_serialized_pem_certificate_list() { - KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.RSA, 2048); - X500Principal subject1 = new X500Principal("CN=myservice"); + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + X500Principal subject1 = new X500Principal("CN=myservice1"); X509Certificate cert1 = TestUtils.createCertificate(keypair, subject1); - X500Principal subject2 = new X500Principal("CN=myservice"); + X500Principal subject2 = new X500Principal("CN=myservice2"); X509Certificate cert2 = TestUtils.createCertificate(keypair, subject2); List<X509Certificate> certificateList = Arrays.asList(cert1, cert2); String pem = X509CertificateUtils.toPem(certificateList); @@ -51,7 +53,7 @@ public class X509CertificateUtilsTest { @Test public void can_list_subject_alternative_names() { - KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.RSA, 2048); + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); X500Principal subject = new X500Principal("CN=myservice"); SubjectAlternativeName san = new SubjectAlternativeName(DNS_NAME, "dns-san"); X509Certificate cert = X509CertificateBuilder @@ -60,8 +62,8 @@ public class X509CertificateUtilsTest { subject, Instant.now(), Instant.now().plus(1, ChronoUnit.DAYS), - SignatureAlgorithm.SHA256_WITH_RSA, - 1) + SignatureAlgorithm.SHA512_WITH_ECDSA, + BigInteger.valueOf(1)) .addSubjectAlternativeName(san) .build(); diff --git a/vespajlib/src/test/java/com/yahoo/security/tls/TransportSecurityOptionsTest.java b/vespajlib/src/test/java/com/yahoo/security/tls/TransportSecurityOptionsTest.java new file mode 100644 index 00000000000..ad80c52ae2a --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/tls/TransportSecurityOptionsTest.java @@ -0,0 +1,24 @@ +package com.yahoo.security.tls; + +import org.junit.Test; + +import java.nio.file.Path; +import java.nio.file.Paths; + +import static org.junit.Assert.*; + +/** + * @author bjorncs + */ +public class TransportSecurityOptionsTest { + + private static final Path TEST_CONFIG_FILE = Paths.get("src/test/resources/transport-security-options.json"); + + @Test + public void can_read_options_from_json_file() { + TransportSecurityOptions expectedOptions = new TransportSecurityOptions("myhost.key", "certs.pem", "my_cas.pem"); + TransportSecurityOptions actualOptions = TransportSecurityOptions.fromJsonFile(TEST_CONFIG_FILE); + assertEquals(expectedOptions, actualOptions); + } + +}
\ No newline at end of file diff --git a/vespajlib/src/test/resources/transport-security-options.json b/vespajlib/src/test/resources/transport-security-options.json new file mode 100644 index 00000000000..0506c130722 --- /dev/null +++ b/vespajlib/src/test/resources/transport-security-options.json @@ -0,0 +1,7 @@ +{ + "files": { + "private-key": "myhost.key", + "ca-certificates": "my_cas.pem", + "certificates": "certs.pem" + } +}
\ No newline at end of file diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt index 33553da9422..fb3b08b325f 100644 --- a/vespalib/CMakeLists.txt +++ b/vespalib/CMakeLists.txt @@ -33,6 +33,7 @@ vespa_define_module( src/tests/data/memory_input src/tests/data/output_writer src/tests/data/simple_buffer + src/tests/data/smart_buffer src/tests/delegatelist src/tests/dotproduct src/tests/dual_merge_director @@ -56,6 +57,8 @@ vespa_define_module( src/tests/net/send_fd src/tests/net/socket src/tests/net/socket_spec + src/tests/net/tls/openssl_impl + src/tests/net/tls/transport_options src/tests/objects/nbostream src/tests/optimized src/tests/printable @@ -118,6 +121,8 @@ vespa_define_module( src/vespa/vespalib/io src/vespa/vespalib/locale src/vespa/vespalib/net + src/vespa/vespalib/net/tls + src/vespa/vespalib/net/tls/impl src/vespa/vespalib/objects src/vespa/vespalib/stllike src/vespa/vespalib/test diff --git a/vespalib/src/tests/alloc/alloc_test.cpp b/vespalib/src/tests/alloc/alloc_test.cpp index 0e52d06a2d5..dd4adfc2fa1 100644 --- a/vespalib/src/tests/alloc/alloc_test.cpp +++ b/vespalib/src/tests/alloc/alloc_test.cpp @@ -179,6 +179,7 @@ TEST("auto alloced mmap alloc can not be extended if no room") { } TEST("mmap alloc can be extended if room") { + Alloc dummy = Alloc::allocMMap(100); Alloc reserved = Alloc::allocMMap(100); Alloc buf = Alloc::allocMMap(100); @@ -187,6 +188,7 @@ TEST("mmap alloc can be extended if room") { } TEST("mmap alloc can not be extended if no room") { + Alloc dummy = Alloc::allocMMap(100); Alloc reserved = Alloc::allocMMap(100); Alloc buf = Alloc::allocMMap(100); diff --git a/vespalib/src/tests/data/smart_buffer/CMakeLists.txt b/vespalib/src/tests/data/smart_buffer/CMakeLists.txt new file mode 100644 index 00000000000..e7468f4f508 --- /dev/null +++ b/vespalib/src/tests/data/smart_buffer/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_smart_buffer_test_app TEST + SOURCES + smart_buffer_test.cpp + DEPENDS + vespalib +) +vespa_add_test(NAME vespalib_smart_buffer_test_app COMMAND vespalib_smart_buffer_test_app) diff --git a/vespalib/src/tests/data/smart_buffer/smart_buffer_test.cpp b/vespalib/src/tests/data/smart_buffer/smart_buffer_test.cpp new file mode 100644 index 00000000000..360afba091a --- /dev/null +++ b/vespalib/src/tests/data/smart_buffer/smart_buffer_test.cpp @@ -0,0 +1,133 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include <vespa/vespalib/testkit/test_kit.h> +#include <vespa/vespalib/data/smart_buffer.h> + +using namespace vespalib; + +void checkMemory(const vespalib::string &expect, const Memory &mem) { + EXPECT_EQUAL(expect, vespalib::string(mem.data, mem.size)); +} + +void checkBuffer(const vespalib::string &expect, SmartBuffer &buf) { + TEST_DO(checkMemory(expect, buf.obtain())); +} + +void write_buf(const vespalib::string &str, SmartBuffer &buf) { + WritableMemory mem = buf.reserve(str.size()); + for (size_t i = 0; i < str.size(); ++i) { + mem.data[i] = str.data()[i]; + } + buf.commit(str.size()); +} + +TEST("require that basic read/write works") { + SmartBuffer buf(3); + TEST_DO(checkBuffer("", buf)); + { // read from empty buffer + EXPECT_EQUAL(0u, buf.obtain().size); + } + { // write to buffer + WritableMemory mem = buf.reserve(10); + TEST_DO(checkBuffer("", buf)); + EXPECT_LESS_EQUAL(10u, mem.size); + mem.data[0] = 'a'; + mem.data[1] = 'b'; + mem.data[2] = 'c'; + EXPECT_EQUAL(&buf, &buf.commit(3)); + mem = buf.reserve(0); + TEST_DO(checkBuffer("abc", buf)); + EXPECT_LESS_EQUAL(0u, mem.size); + } + { // read without evicting last byte + Memory mem = buf.obtain(); + TEST_DO(checkBuffer("abc", buf)); + TEST_DO(checkMemory("abc", mem)); + EXPECT_EQUAL(&buf, &buf.evict(2)); + mem = buf.obtain(); + TEST_DO(checkBuffer("c", buf)); + TEST_DO(checkMemory("c", mem)); + mem = buf.obtain(); + TEST_DO(checkBuffer("c", buf)); + TEST_DO(checkMemory("c", mem)); + } + { // write more to buffer + WritableMemory mem = buf.reserve(10); + EXPECT_LESS_EQUAL(10u, mem.size); + TEST_DO(checkBuffer("c", buf)); + mem.data[0] = 'd'; + EXPECT_EQUAL(&buf, &buf.commit(1)); + mem = buf.reserve(5); + TEST_DO(checkBuffer("cd", buf)); + EXPECT_LESS_EQUAL(5u, mem.size); + } + { // read until end + Memory mem = buf.obtain(); + TEST_DO(checkBuffer("cd", buf)); + TEST_DO(checkMemory("cd", mem)); + EXPECT_EQUAL(&buf, &buf.evict(1)); + mem = buf.obtain(); + TEST_DO(checkBuffer("d", buf)); + TEST_DO(checkMemory("d", mem)); + EXPECT_EQUAL(&buf, &buf.evict(1)); + mem = buf.obtain(); + TEST_DO(checkBuffer("", buf)); + TEST_DO(checkMemory("", mem)); + } +} + +TEST("require that requested initial size is not adjusted") { + SmartBuffer buf(400); + EXPECT_EQUAL(buf.capacity(), 400u); +} + +TEST("require that buffer auto-resets when empty") { + SmartBuffer buf(64); + EXPECT_EQUAL(buf.reserve(10).size, 64u); + write_buf("abc", buf); + EXPECT_EQUAL(buf.reserve(10).size, 61u); + buf.evict(3); + EXPECT_EQUAL(buf.reserve(10).size, 64u); +} + +TEST("require that buffer can grow") { + SmartBuffer buf(64); + EXPECT_EQUAL(buf.capacity(), 64u); + write_buf("abc", buf); + write_buf("abc", buf); + buf.evict(3); + EXPECT_EQUAL(buf.reserve(70).size, size_t(128 - 3)); + TEST_DO(checkBuffer("abc", buf)); + EXPECT_EQUAL(buf.capacity(), 128u); +} + +TEST("require that buffer can grow more than 2x") { + SmartBuffer buf(64); + EXPECT_EQUAL(buf.capacity(), 64u); + write_buf("abc", buf); + write_buf("abc", buf); + buf.evict(3); + EXPECT_EQUAL(buf.reserve(170).size, 170u); + TEST_DO(checkBuffer("abc", buf)); + EXPECT_EQUAL(buf.capacity(), 173u); +} + +TEST("require that buffer can be compacted") { + SmartBuffer buf(16); + EXPECT_EQUAL(buf.capacity(), 16u); + write_buf("abc", buf); + write_buf("abc", buf); + buf.evict(3); + write_buf("abc", buf); + buf.evict(3); + write_buf("abc", buf); + buf.evict(3); + write_buf("abc", buf); + buf.evict(3); + EXPECT_EQUAL(buf.reserve(0).size, 1u); + write_buf("abc", buf); + TEST_DO(checkBuffer("abcabc", buf)); + EXPECT_EQUAL(buf.capacity(), 16u); + EXPECT_EQUAL(buf.reserve(0).size, 10u); +} + +TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/vespalib/src/tests/io/fileutil/fileutiltest.cpp b/vespalib/src/tests/io/fileutil/fileutiltest.cpp index 56ee83f697a..8345a8dcb99 100644 --- a/vespalib/src/tests/io/fileutil/fileutiltest.cpp +++ b/vespalib/src/tests/io/fileutil/fileutiltest.cpp @@ -405,7 +405,7 @@ TEST("require that vespalib::copy works") f.write(buffer.get(), 4096, 0); f.close(); std::cerr << "Simple copy\n"; - // Simple copy works (512b dividable file) + // Simple copy works (4096b dividable file) copy("myfile", "targetfile"); ASSERT_TRUE(system("diff myfile targetfile") == 0); std::cerr << "Overwriting\n"; diff --git a/vespalib/src/tests/net/tls/openssl_impl/CMakeLists.txt b/vespalib/src/tests/net/tls/openssl_impl/CMakeLists.txt new file mode 100644 index 00000000000..799e2291d7c --- /dev/null +++ b/vespalib/src/tests/net/tls/openssl_impl/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_net_tls_openssl_impl_test_app TEST + SOURCES + openssl_impl_test.cpp + DEPENDS + vespalib +) +vespa_add_test(NAME vespalib_net_tls_openssl_impl_test_app COMMAND vespalib_net_tls_openssl_impl_test_app) + diff --git a/vespalib/src/tests/net/tls/openssl_impl/openssl_impl_test.cpp b/vespalib/src/tests/net/tls/openssl_impl/openssl_impl_test.cpp new file mode 100644 index 00000000000..4e8bf31e75e --- /dev/null +++ b/vespalib/src/tests/net/tls/openssl_impl/openssl_impl_test.cpp @@ -0,0 +1,134 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include <vespa/vespalib/testkit/test_kit.h> +#include <vespa/vespalib/net/tls/tls_context.h> +#include <vespa/vespalib/net/tls/transport_security_options.h> +#include <vespa/vespalib/net/tls/crypto_codec.h> +#include <vespa/vespalib/test/make_tls_options_for_testing.h> +#include <iostream> +#include <stdlib.h> + +using namespace vespalib; +using namespace vespalib::net::tls; + +const char* decode_state_to_str(DecodeResult::State state) noexcept { + switch (state) { + case DecodeResult::State::Failed: return "Broken"; + case DecodeResult::State::OK: return "OK"; + case DecodeResult::State::NeedsMorePeerData: return "NeedsMorePeerData"; + default: + abort(); + } +} + +const char* hs_state_to_str(HandshakeResult::State state) noexcept { + switch (state) { + case HandshakeResult::State::Failed: return "Broken"; + case HandshakeResult::State::Done: return "Done"; + case HandshakeResult::State::NeedsMorePeerData: return "NeedsMorePeerData"; + default: + abort(); + } +} + +void log_handshake_result(const char* mode, const HandshakeResult& res) { + fprintf(stderr, "(handshake) %s consumed %zu peer bytes, wrote %zu peer bytes. State: %s\n", + mode, res.bytes_consumed, res.bytes_produced, + hs_state_to_str(res.state)); +} + +void log_encode_result(const char* mode, const EncodeResult& res) { + fprintf(stderr, "(encode) %s read %zu plaintext, wrote %zu cipher. State: %s\n", + mode, res.bytes_consumed, res.bytes_produced, + res.failed ? "Broken! D:" : "OK"); +} + +void log_decode_result(const char* mode, const DecodeResult& res) { + fprintf(stderr, "(decode) %s read %zu cipher, wrote %zu plaintext. State: %s\n", + mode, res.bytes_consumed, res.bytes_produced, + decode_state_to_str(res.state)); +} + +bool complete_handshake(CryptoCodec& client, CryptoCodec& server) { + // Not using vespalib::string here since it doesn't have erase(iter, length) implemented. + std::string client_to_server_buf; + std::string server_to_client_buf; + + HandshakeResult cli_res; + HandshakeResult serv_res; + while (!(cli_res.done() && serv_res.done())) { + client_to_server_buf.resize(client.min_encode_buffer_size()); + server_to_client_buf.resize(server.min_encode_buffer_size()); + + cli_res = client.handshake(server_to_client_buf.data(), serv_res.bytes_produced, + client_to_server_buf.data(), client_to_server_buf.size()); + log_handshake_result("client", cli_res); + server_to_client_buf.erase(server_to_client_buf.begin(), server_to_client_buf.begin() + cli_res.bytes_consumed); + + serv_res = server.handshake(client_to_server_buf.data(), cli_res.bytes_produced, + server_to_client_buf.data(), server_to_client_buf.size()); + log_handshake_result("server", serv_res); + client_to_server_buf.erase(client_to_server_buf.begin(), client_to_server_buf.begin() + serv_res.bytes_consumed); + + if (cli_res.failed() || serv_res.failed()) { + return false; + } + } + return true; +} + +TEST("client and server can complete handshake") { + // TODO move to fixture + auto tls_opts = vespalib::test::make_tls_options_for_testing(); + auto tls_ctx = TlsContext::create_default_context(tls_opts); + auto client = CryptoCodec::create_default_codec(*tls_ctx, CryptoCodec::Mode::Client); + auto server = CryptoCodec::create_default_codec(*tls_ctx, CryptoCodec::Mode::Server); + + EXPECT_TRUE(complete_handshake(*client, *server)); +} + +TEST("client can send single data frame to server after handshake") { + // TODO move to fixture + auto tls_opts = vespalib::test::make_tls_options_for_testing(); + auto tls_ctx = TlsContext::create_default_context(tls_opts); + auto client = CryptoCodec::create_default_codec(*tls_ctx, CryptoCodec::Mode::Client); + auto server = CryptoCodec::create_default_codec(*tls_ctx, CryptoCodec::Mode::Server); + + ASSERT_TRUE(complete_handshake(*client, *server)); + + std::string client_to_server_buf; + client_to_server_buf.resize(client->min_encode_buffer_size()); + + std::string client_plaintext = "Hellooo world! :D"; + auto cli_res = client->encode(client_plaintext.data(), client_plaintext.size(), + client_to_server_buf.data(), client_to_server_buf.size()); + log_encode_result("client", cli_res); + + std::string server_plaintext_out; + server_plaintext_out.resize(server->min_decode_buffer_size()); + auto serv_res = server->decode(client_to_server_buf.data(), cli_res.bytes_produced, + server_plaintext_out.data(), server_plaintext_out.size()); + log_decode_result("server", serv_res); + + ASSERT_FALSE(cli_res.failed); + ASSERT_FALSE(serv_res.failed()); + + ASSERT_TRUE(serv_res.state == DecodeResult::State::OK); + std::string data_received(server_plaintext_out.data(), serv_res.bytes_produced); + EXPECT_EQUAL(client_plaintext, data_received); +} + +/* + * TODO tests: + * - full duplex read/write + * - read and write of > frame size data + * - handshakes with multi frame writes + * - completed handshake with pipelined data frame + * - short ciphertext reads on decode + * - short plaintext writes on decode (.. if we even want to support this..) + * - short ciphertext write on encode + * - peer certificate validation on server + * - peer certificate validation on client + * - detection of peer shutdown session + */ + +TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/vespalib/src/tests/net/tls/transport_options/CMakeLists.txt b/vespalib/src/tests/net/tls/transport_options/CMakeLists.txt new file mode 100644 index 00000000000..ee1e2477708 --- /dev/null +++ b/vespalib/src/tests/net/tls/transport_options/CMakeLists.txt @@ -0,0 +1,10 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_net_tls_transport_options_test_app TEST + SOURCES + transport_options_reading_test.cpp + DEPENDS + vespalib +) +vespa_add_test(NAME vespalib_net_tls_transport_options_test_app + COMMAND vespalib_net_tls_transport_options_test_app) + diff --git a/vespalib/src/tests/net/tls/transport_options/dummy_ca_certs.txt b/vespalib/src/tests/net/tls/transport_options/dummy_ca_certs.txt new file mode 100644 index 00000000000..b617f6f17e4 --- /dev/null +++ b/vespalib/src/tests/net/tls/transport_options/dummy_ca_certs.txt @@ -0,0 +1 @@ +My CA certificates diff --git a/vespalib/src/tests/net/tls/transport_options/dummy_certs.txt b/vespalib/src/tests/net/tls/transport_options/dummy_certs.txt new file mode 100644 index 00000000000..088b91ff770 --- /dev/null +++ b/vespalib/src/tests/net/tls/transport_options/dummy_certs.txt @@ -0,0 +1 @@ +My certificate chain diff --git a/vespalib/src/tests/net/tls/transport_options/dummy_privkey.txt b/vespalib/src/tests/net/tls/transport_options/dummy_privkey.txt new file mode 100644 index 00000000000..f29585fe31f --- /dev/null +++ b/vespalib/src/tests/net/tls/transport_options/dummy_privkey.txt @@ -0,0 +1 @@ +My private key diff --git a/vespalib/src/tests/net/tls/transport_options/ok_config.json b/vespalib/src/tests/net/tls/transport_options/ok_config.json new file mode 100644 index 00000000000..dd2591661dc --- /dev/null +++ b/vespalib/src/tests/net/tls/transport_options/ok_config.json @@ -0,0 +1,7 @@ +{ + "files":{ + "private-key": "dummy_privkey.txt", + "ca-certificates": "dummy_ca_certs.txt", + "certificates": "dummy_certs.txt" + } +} diff --git a/vespalib/src/tests/net/tls/transport_options/transport_options_reading_test.cpp b/vespalib/src/tests/net/tls/transport_options/transport_options_reading_test.cpp new file mode 100644 index 00000000000..1ce4a4353d0 --- /dev/null +++ b/vespalib/src/tests/net/tls/transport_options/transport_options_reading_test.cpp @@ -0,0 +1,65 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include <vespa/vespalib/io/fileutil.h> +#include <vespa/vespalib/net/tls/transport_security_options.h> +#include <vespa/vespalib/net/tls/transport_security_options_reading.h> +#include <vespa/vespalib/testkit/test_kit.h> +#include <vespa/vespalib/util/exceptions.h> + +using namespace vespalib; +using namespace vespalib::net::tls; + +TEST("can load TLS credentials via config file") { + auto opts = read_options_from_json_file("ok_config.json"); + ASSERT_TRUE(opts.get() != nullptr); + // Obviously we'd need to change this to actual PEM data if config reading started + // actually verifying the _content_ of files, not just reading them. + EXPECT_EQUAL("My private key\n", opts->private_key_pem()); + EXPECT_EQUAL("My CA certificates\n", opts->ca_certs_pem()); + EXPECT_EQUAL("My certificate chain\n", opts->cert_chain_pem()); +} + +TEST("missing JSON file throws exception") { + EXPECT_EXCEPTION(read_options_from_json_file("missing_config.json"), IllegalArgumentException, + "TLS config file 'missing_config.json' could not be read"); +} + +TEST("bad JSON content throws exception") { + const char* bad_json = "hello world :D"; + EXPECT_EXCEPTION(read_options_from_json_string(bad_json), IllegalArgumentException, + "Provided TLS config file is not valid JSON"); +} + +TEST("missing 'files' field throws exception") { + const char* incomplete_json = R"({})"; + EXPECT_EXCEPTION(read_options_from_json_string(incomplete_json), IllegalArgumentException, + "TLS config root field 'files' is missing or empty"); +} + +TEST("missing 'private-key' field throws exception") { + const char* incomplete_json = R"({"files":{"certificates":"dummy_certs.txt","ca-certificates":"dummy_ca_certs.txt"}})"; + EXPECT_EXCEPTION(read_options_from_json_string(incomplete_json), IllegalArgumentException, + "TLS config field 'private-key' has not been set"); +} + +TEST("missing 'certificates' field throws exception") { + const char* incomplete_json = R"({"files":{"private-key":"dummy_privkey.txt","ca-certificates":"dummy_ca_certs.txt"}})"; + EXPECT_EXCEPTION(read_options_from_json_string(incomplete_json), IllegalArgumentException, + "TLS config field 'certificates' has not been set"); +} + +TEST("missing 'ca-certificates' field throws exception") { + const char* incomplete_json = R"({"files":{"private-key":"dummy_privkey.txt","certificates":"dummy_certs.txt"}})"; + EXPECT_EXCEPTION(read_options_from_json_string(incomplete_json), IllegalArgumentException, + "TLS config field 'ca-certificates' has not been set"); +} + +TEST("missing file referenced by field throws exception") { + const char* incomplete_json = R"({"files":{"private-key":"missing_privkey.txt", + "certificates":"dummy_certs.txt", + "ca-certificates":"dummy_ca_certs.txt"}})"; + EXPECT_EXCEPTION(read_options_from_json_string(incomplete_json), IllegalArgumentException, + "File 'missing_privkey.txt' referenced by TLS config does not exist"); +} + +TEST_MAIN() { TEST_RUN_ALL(); } + diff --git a/vespalib/src/vespa/vespalib/CMakeLists.txt b/vespalib/src/vespa/vespalib/CMakeLists.txt index 480caf8f28d..8261bb8874e 100644 --- a/vespalib/src/vespa/vespalib/CMakeLists.txt +++ b/vespalib/src/vespa/vespalib/CMakeLists.txt @@ -9,8 +9,11 @@ vespa_add_library(vespalib $<TARGET_OBJECTS:vespalib_vespalib_io> $<TARGET_OBJECTS:vespalib_vespalib_locale> $<TARGET_OBJECTS:vespalib_vespalib_net> + $<TARGET_OBJECTS:vespalib_vespalib_net_tls> + $<TARGET_OBJECTS:vespalib_vespalib_net_tls_impl> $<TARGET_OBJECTS:vespalib_vespalib_objects> $<TARGET_OBJECTS:vespalib_vespalib_stllike> + $<TARGET_OBJECTS:vespalib_vespalib_test> $<TARGET_OBJECTS:vespalib_vespalib_testkit> $<TARGET_OBJECTS:vespalib_vespalib_text> $<TARGET_OBJECTS:vespalib_vespalib_time> @@ -20,6 +23,7 @@ vespa_add_library(vespalib $<TARGET_OBJECTS:vespalib_vespalib_xxhash> INSTALL lib64 DEPENDS - vespalib_vespalib_test gcc ) + +vespa_add_target_package_dependency(vespalib OpenSSL) diff --git a/vespalib/src/vespa/vespalib/data/CMakeLists.txt b/vespalib/src/vespa/vespalib/data/CMakeLists.txt index 3a94e00ae33..517d0cd198f 100644 --- a/vespalib/src/vespa/vespalib/data/CMakeLists.txt +++ b/vespalib/src/vespa/vespalib/data/CMakeLists.txt @@ -12,6 +12,7 @@ vespa_add_library(vespalib_vespalib_data OBJECT output.cpp output_writer.cpp simple_buffer.cpp + smart_buffer.cpp writable_memory.cpp DEPENDS ) diff --git a/vespalib/src/vespa/vespalib/data/smart_buffer.cpp b/vespalib/src/vespa/vespalib/data/smart_buffer.cpp new file mode 100644 index 00000000000..401b6729601 --- /dev/null +++ b/vespalib/src/vespa/vespalib/data/smart_buffer.cpp @@ -0,0 +1,68 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "smart_buffer.h" +#include <cassert> + +namespace vespalib { + +void +SmartBuffer::ensure_free(size_t bytes) +{ + if (write_len() >= bytes) { + return; + } + if ((unused() < bytes) || ((unused() * 3) < read_len())) { + size_t new_size = std::max(_data.size() * 2, read_len() + bytes); + alloc::Alloc new_buf(alloc::Alloc::alloc(new_size)); + memcpy(new_buf.get(), read_ptr(), read_len()); + _data.swap(new_buf); + } else { + memmove(_data.get(), read_ptr(), read_len()); + } + _write_pos = read_len(); + _read_pos = 0; +} + +SmartBuffer::SmartBuffer(size_t initial_size) + : _data(alloc::Alloc::alloc(initial_size)), + _read_pos(0), + _write_pos(0) +{ +} + +SmartBuffer::~SmartBuffer() = default; + +Memory +SmartBuffer::obtain() +{ + return Memory(read_ptr(), read_len()); +} + +Input & +SmartBuffer::evict(size_t bytes) +{ + assert(read_len() >= bytes); + _read_pos += bytes; + if (_read_pos == _write_pos) { + _read_pos = 0; + _write_pos = 0; + } + return *this; +} + +WritableMemory +SmartBuffer::reserve(size_t bytes) +{ + ensure_free(bytes); + return WritableMemory(write_ptr(), write_len()); +} + +Output & +SmartBuffer::commit(size_t bytes) +{ + assert(write_len() >= bytes); + _write_pos += bytes; + return *this; +} + +} // namespace vespalib diff --git a/vespalib/src/vespa/vespalib/data/smart_buffer.h b/vespalib/src/vespa/vespalib/data/smart_buffer.h new file mode 100644 index 00000000000..f7c4dd05c3e --- /dev/null +++ b/vespalib/src/vespa/vespalib/data/smart_buffer.h @@ -0,0 +1,41 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "input.h" +#include "output.h" +#include <vespa/vespalib/util/alloc.h> + +namespace vespalib { + +/** + * A somewhat smarter buffer compared to SimpleBuffer. Keeps track of + * data in a continuous memory segment. Tries to limit copying of + * data. + **/ +class SmartBuffer : public Input, + public Output +{ +private: + alloc::Alloc _data; + size_t _read_pos; + size_t _write_pos; + + const char *read_ptr() const { return (const char *)(_data.get()) + _read_pos; } + size_t read_len() const { return (_write_pos - _read_pos); } + char *write_ptr() { return (char *)(_data.get()) + _write_pos; } + size_t write_len() const { return (_data.size() - _write_pos); } + size_t unused() const { return (_data.size() - read_len()); } + void ensure_free(size_t bytes); + +public: + SmartBuffer(size_t initial_size); + ~SmartBuffer(); + size_t capacity() const { return _data.size(); } + Memory obtain() override; + Input &evict(size_t bytes) override; + WritableMemory reserve(size_t bytes) override; + Output &commit(size_t bytes) override; +}; + +} // namespace vespalib diff --git a/vespalib/src/vespa/vespalib/io/fileutil.cpp b/vespalib/src/vespa/vespalib/io/fileutil.cpp index 5ab5fb99a0d..5acde54047a 100644 --- a/vespalib/src/vespa/vespalib/io/fileutil.cpp +++ b/vespalib/src/vespa/vespalib/io/fileutil.cpp @@ -699,7 +699,7 @@ rename(const string & frompath, const string & topath, namespace { uint32_t bufferSize = 1024 * 1024; - uint32_t diskAlignmentSize = 512; + uint32_t diskAlignmentSize = 4096; } diff --git a/vespalib/src/vespa/vespalib/net/crypto_engine.cpp b/vespalib/src/vespa/vespalib/net/crypto_engine.cpp index 8832b4b1cfe..ec225311b60 100644 --- a/vespalib/src/vespa/vespalib/net/crypto_engine.cpp +++ b/vespalib/src/vespa/vespalib/net/crypto_engine.cpp @@ -5,6 +5,11 @@ #include <chrono> #include <thread> #include <vespa/vespalib/xxhash/xxhash.h> +#include <vespa/vespalib/stllike/string.h> +#include <vespa/vespalib/net/tls/transport_security_options.h> +#include <vespa/vespalib/net/tls/transport_security_options_reading.h> +#include <vespa/vespalib/net/tls/tls_crypto_engine.h> +#include <vespa/vespalib/data/smart_buffer.h> #include <assert.h> namespace vespalib { @@ -42,14 +47,14 @@ public: class XorCryptoSocket : public CryptoSocket { private: - static constexpr size_t CHUNK_SIZE = 4096; + static constexpr size_t CHUNK_SIZE = 16 * 1024; enum class OP { READ_KEY, WRITE_KEY }; std::vector<OP> _op_stack; - char _my_key; - char _peer_key; - std::vector<char> _readbuf; - std::vector<char> _writebuf; - SocketHandle _socket; + char _my_key; + char _peer_key; + SmartBuffer _input; + SmartBuffer _output; + SocketHandle _socket; bool is_blocked(ssize_t res, int error) const { return ((res < 0) && ((error == EWOULDBLOCK) || (error == EAGAIN))); @@ -91,8 +96,8 @@ public: : std::vector<OP>({OP::READ_KEY, OP::WRITE_KEY})), _my_key(gen_key()), _peer_key(0), - _readbuf(), - _writebuf(), + _input(CHUNK_SIZE * 2), + _output(CHUNK_SIZE * 2), _socket(std::move(socket)) {} int get_fd() const override { return _socket.get(); } HandshakeResult handshake() override { @@ -107,58 +112,68 @@ public: } size_t min_read_buffer_size() const override { return 1; } ssize_t read(char *buf, size_t len) override { - if (_readbuf.empty()) { - _readbuf.resize(CHUNK_SIZE); - ssize_t res = _socket.read(&_readbuf[0], _readbuf.size()); + if (_input.obtain().size < CHUNK_SIZE) { + auto dst = _input.reserve(CHUNK_SIZE); + ssize_t res = _socket.read(dst.data, dst.size); if (res > 0) { - _readbuf.resize(res); + _input.commit(res); } else { - _readbuf.clear(); - return res; + return res; // eof/error } } return drain(buf, len); } ssize_t drain(char *buf, size_t len) override { - size_t frame = std::min(len, _readbuf.size()); + auto src = _input.obtain(); + size_t frame = std::min(len, src.size); for (size_t i = 0; i < frame; ++i) { - buf[i] = (_readbuf[i] ^ _my_key); + buf[i] = (src.data[i] ^ _my_key); } - _readbuf.erase(_readbuf.begin(), _readbuf.begin() + frame); + _input.evict(frame); return frame; } ssize_t write(const char *buf, size_t len) override { - ssize_t res = flush(); - while (res > 0) { - res = flush(); - } - if (res < 0) { - return res; + if (_output.obtain().size >= CHUNK_SIZE) { + if (flush() < 0) { + return -1; + } + if (_output.obtain().size > 0) { + errno = EWOULDBLOCK; + return -1; + } } size_t frame = std::min(len, CHUNK_SIZE); + auto dst = _output.reserve(frame); for (size_t i = 0; i < frame; ++i) { - _writebuf.push_back(buf[i] ^ _peer_key); + dst.data[i] = (buf[i] ^ _peer_key); } + _output.commit(frame); return frame; } ssize_t flush() override { - if (!_writebuf.empty()) { - ssize_t res = _socket.write(&_writebuf[0], _writebuf.size()); + auto pending = _output.obtain(); + if (pending.size > 0) { + ssize_t res = _socket.write(pending.data, pending.size); if (res > 0) { - _writebuf.erase(_writebuf.begin(), _writebuf.begin() + res); + _output.evict(res); + return 1; // progress } else { assert(res < 0); + return -1; // error } - return res; } - return 0; + return 0; // done } }; CryptoEngine::SP create_default_crypto_engine() { - // TODO: check VESPA_TLS_CONFIG_FILE here - // return std::make_shared<XorCryptoEngine>(); - return std::make_shared<NullCryptoEngine>(); + const char *env = getenv("VESPA_TLS_CONFIG_FILE"); + vespalib::string cfg_file = env ? env : ""; + if (cfg_file.empty()) { + return std::make_shared<NullCryptoEngine>(); + } + auto tls_opts = net::tls::read_options_from_json_file(cfg_file); + return std::make_shared<TlsCryptoEngine>(*tls_opts); } } // namespace vespalib::<unnamed> diff --git a/vespalib/src/vespa/vespalib/net/crypto_socket.h b/vespalib/src/vespa/vespalib/net/crypto_socket.h index 7fe7871960f..f78f7fc0ce7 100644 --- a/vespalib/src/vespa/vespalib/net/crypto_socket.h +++ b/vespalib/src/vespa/vespalib/net/crypto_socket.h @@ -74,13 +74,16 @@ struct CryptoSocket { virtual ssize_t write(const char *buf, size_t len) = 0; /** - * Try to flush data in the write pipeline that is not depenedent + * Try to flush data in the write pipeline that is not dependent * on data not yet written by the application into the underlying * socket. This is to enable the application to identify pending * work that may not be completed until the underlying socket is * ready for writing more data. The semantics are the same as with * a normal socket write (errno, etc.) with the exception that 0 - * will be returned when there is no more data to flush. + * will be returned when there is no more data to flush and any + * positive number indicates that we were able to flush something + * (it does not need to reflect the actual number of bytes written + * to the underlying socket). **/ virtual ssize_t flush() = 0; diff --git a/vespalib/src/vespa/vespalib/net/tls/CMakeLists.txt b/vespalib/src/vespa/vespalib/net/tls/CMakeLists.txt new file mode 100644 index 00000000000..2d34a3e1c80 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/CMakeLists.txt @@ -0,0 +1,15 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_library(vespalib_vespalib_net_tls OBJECT + SOURCES + crypto_codec.cpp + crypto_codec_adapter.cpp + crypto_exception.cpp + tls_context.cpp + tls_crypto_engine.cpp + transport_security_options.cpp + transport_security_options_reading.cpp + DEPENDS +) +find_package(OpenSSL) +target_include_directories(vespalib_vespalib_net_tls PUBLIC ${OPENSSL_INCLUDE_DIR}) + diff --git a/vespalib/src/vespa/vespalib/net/tls/crypto_codec.cpp b/vespalib/src/vespa/vespalib/net/tls/crypto_codec.cpp new file mode 100644 index 00000000000..b36913d20e3 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/crypto_codec.cpp @@ -0,0 +1,15 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "crypto_codec.h" +#include <vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.h> +#include <vespa/vespalib/net/tls/impl/openssl_tls_context_impl.h> +#include <cassert> + +namespace vespalib::net::tls { + +std::unique_ptr<CryptoCodec> CryptoCodec::create_default_codec(TlsContext& ctx, Mode mode) { + auto* ssl_ctx = dynamic_cast<impl::OpenSslTlsContextImpl*>(&ctx); + assert(ssl_ctx != nullptr); + return std::make_unique<impl::OpenSslCryptoCodecImpl>(*ssl_ctx->native_context(), mode); +} + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/crypto_codec.h b/vespalib/src/vespa/vespalib/net/tls/crypto_codec.h new file mode 100644 index 00000000000..6e690c809a5 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/crypto_codec.h @@ -0,0 +1,124 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <memory> + +namespace vespalib::net::tls { + +struct HandshakeResult { + // Handshake bytes consumed from peer. + size_t bytes_consumed = 0; + // Handshake bytes produced that must be sent to the peer. + size_t bytes_produced = 0; + enum class State { + Failed, + Done, + NeedsMorePeerData + }; + State state = State::Failed; + + bool failed() const noexcept { return (state == State::Failed); } + bool done() const noexcept { return (state == State::Done); } +}; + +struct EncodeResult { + // Plaintext bytes consumed + size_t bytes_consumed = 0; + // Ciphertext bytes produced that must be sent to the peer + size_t bytes_produced = 0; + bool failed = true; +}; + +struct DecodeResult { + // Ciphertext bytes consumed from peer + size_t bytes_consumed = 0; + // Plaintext bytes produced. + size_t bytes_produced = 0; + enum class State { + Failed, + OK, + NeedsMorePeerData + // TODO add Closed/Shutdown as own state? + }; + State state = State::Failed; + + bool failed() const noexcept { return (state == State::Failed); } +}; + +class TlsContext; + +// TODO move to different namespace, not dependent on TLS? + +/* + * A CryptoCodec provides a fully transport-independent way of negotiating + * a secure, authenticated session towards another peer. The codec requires + * the caller to handle any and all actual data transfer + */ +class CryptoCodec { +public: + enum class Mode { + Client, Server + }; + + virtual ~CryptoCodec() = default; + + /* + * Minimum buffer size required to represent one wire format frame + * of encrypted (ciphertext) data, including frame overhead. + */ + virtual size_t min_encode_buffer_size() const noexcept = 0; + /* + * Minimum buffer size required to represent the decoded (plaintext) + * output of a single frame of encrypted data. + */ + virtual size_t min_decode_buffer_size() const noexcept = 0; + + /* + * Precondition: to_peer_buf_size >= min_encode_buffer_size() + * Postcondition: if result.done(), the handshake process has completed + * and data may be passed through encode()/decode(). + */ + virtual HandshakeResult handshake(const char* from_peer, size_t from_peer_buf_size, + char* to_peer, size_t to_peer_buf_size) noexcept = 0; + + /* + * Encodes a single ciphertext frame into `ciphertext`. If plaintext_size + * is greater than can fit into a frame, the returned result's consumed_bytes + * field will be < plaintext_size. The number of actual ciphertext bytes produced + * is available in the returned result's produced_bytes field. + * + * Precondition: handshake must be completed + * Precondition: ciphertext_size >= min_encode_buffer_size(), i.e. it must be + * possible to encode at least 1 frame. + * Postcondition: if plaintext_size > 0 and result.failed == false, a single + * frame of ciphertext has been written into the to_peer buffer. + * Size of written frame is given by result.bytes_produced. This + * includes all protocol-specific frame overhead. + */ + virtual EncodeResult encode(const char* plaintext, size_t plaintext_size, + char* ciphertext, size_t ciphertext_size) noexcept = 0; + /* + * Attempt to decode ciphertext sent by the peer into plaintext. Since + * ciphertext is sent in frames, it's possible that invoking decode() + * may produce a CodecResult with a state of `NeedsMorePeerData` if a + * complete frame is not present in `ciphertext`. In this case, decode() + * must be called again once more data is available. + * + * Precondition: handshake must be completed + * Precondition: plaintext_size >= min_decode_buffer_size() + * Postcondition: if result.state == DecodeResult::State::OK, at least 1 + * complete frame has been written to the `plaintext` buffer + */ + virtual DecodeResult decode(const char* ciphertext, size_t ciphertext_size, + char* plaintext, size_t plaintext_size) noexcept = 0; + + /* + * Creates an implementation defined CryptoCodec that provides at least TLSv1.2 + * compliant handshaking and full duplex data transfer. + * + * Throws CryptoException if resources cannot be allocated for the codec. + */ + static std::unique_ptr<CryptoCodec> create_default_codec(TlsContext& ctx, Mode mode); +}; + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/crypto_codec_adapter.cpp b/vespalib/src/vespa/vespalib/net/tls/crypto_codec_adapter.cpp new file mode 100644 index 00000000000..f8fa4ab2b53 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/crypto_codec_adapter.cpp @@ -0,0 +1,152 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "crypto_codec_adapter.h" +#include <assert.h> + +namespace vespalib::net::tls { + +CryptoSocket::HandshakeResult +CryptoCodecAdapter::hs_try_flush() +{ + auto flush_res = flush_all(); + if (flush_res == 0) { + return HandshakeResult::DONE; + } else if (is_blocked(flush_res, errno)) { + return HandshakeResult::NEED_WRITE; + } else { + return HandshakeResult::FAIL; + } +} + +CryptoSocket::HandshakeResult +CryptoCodecAdapter::hs_try_fill() +{ + auto fill_res = fill_input(); + if (fill_res > 0) { + return HandshakeResult::DONE; + } else if (is_blocked(fill_res, errno)) { + return HandshakeResult::NEED_READ; + } else { // eof included here + return HandshakeResult::FAIL; + } +} + +ssize_t +CryptoCodecAdapter::fill_input() +{ + if (_input.obtain().size < _codec->min_encode_buffer_size()) { + auto dst = _input.reserve(_codec->min_encode_buffer_size()); + ssize_t res = _socket.read(dst.data, dst.size); + if (res > 0) { + _input.commit(res); + } else { + return res; // eof/error + } + } + return 1; // progress +} + +ssize_t +CryptoCodecAdapter::flush_all() +{ + ssize_t res = flush(); + while (res > 0) { + res = flush(); + } + return res; +} + +CryptoSocket::HandshakeResult +CryptoCodecAdapter::handshake() +{ + for (;;) { + auto in = _input.obtain(); + auto out = _output.reserve(_codec->min_encode_buffer_size()); + auto hs_res = _codec->handshake(in.data, in.size, out.data, out.size); + _input.evict(hs_res.bytes_consumed); + _output.commit(hs_res.bytes_produced); + switch (hs_res.state) { + case ::vespalib::net::tls::HandshakeResult::State::Failed: return HandshakeResult::FAIL; + case ::vespalib::net::tls::HandshakeResult::State::Done: return hs_try_flush(); + case ::vespalib::net::tls::HandshakeResult::State::NeedsMorePeerData: + auto flush_res = hs_try_flush(); + if (flush_res != HandshakeResult::DONE) { + return flush_res; + } + auto fill_res = hs_try_fill(); + if (fill_res != HandshakeResult::DONE) { + return fill_res; + } + } + } + return HandshakeResult::DONE; +} + +ssize_t +CryptoCodecAdapter::read(char *buf, size_t len) +{ + auto fill_res = fill_input(); + if (fill_res <= 0) { + return fill_res; + } + auto drain_res = drain(buf, len); + if (drain_res != 0) { + return drain_res; + } + errno = EWOULDBLOCK; + return -1; +} + +ssize_t +CryptoCodecAdapter::drain(char *buf, size_t len) +{ + auto src = _input.obtain(); + auto res = _codec->decode(src.data, src.size, buf, len); + if (res.failed()) { + errno = EIO; + return -1; + } + _input.evict(res.bytes_consumed); + return res.bytes_produced; +} + +ssize_t +CryptoCodecAdapter::write(const char *buf, size_t len) +{ + if (_output.obtain().size >= _codec->min_encode_buffer_size()) { + if (flush() < 0) { + return -1; + } + if (_output.obtain().size > 0) { + errno = EWOULDBLOCK; + return -1; + } + } + auto dst = _output.reserve(_codec->min_encode_buffer_size()); + auto res = _codec->encode(buf, len, dst.data, dst.size); + if (res.failed) { + errno = EIO; + return -1; + } + _output.commit(res.bytes_produced); + return res.bytes_consumed; +} + +ssize_t +CryptoCodecAdapter::flush() +{ + auto pending = _output.obtain(); + if (pending.size > 0) { + ssize_t res = _socket.write(pending.data, pending.size); + if (res > 0) { + _output.evict(res); + return 1; // progress + } else { + assert(res < 0); + return -1; // error + } + } + return 0; // done +} + +} // namespace vespalib::net::tls diff --git a/vespalib/src/vespa/vespalib/net/tls/crypto_codec_adapter.h b/vespalib/src/vespa/vespalib/net/tls/crypto_codec_adapter.h new file mode 100644 index 00000000000..f17693cabff --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/crypto_codec_adapter.h @@ -0,0 +1,46 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/vespalib/net/crypto_socket.h> +#include <vespa/vespalib/net/socket_handle.h> +#include <vespa/vespalib/data/smart_buffer.h> +#include "crypto_codec.h" + +namespace vespalib::net::tls { + +/** + * Component adapting an underlying CryptoCodec to the CryptoSocket + * interface by performing buffer and socket management. + * + * NOTE: initial implementation is for functionality/proof-of-concept + * purposes, not performance. + **/ +class CryptoCodecAdapter : public CryptoSocket +{ +private: + SmartBuffer _input; + SmartBuffer _output; + SocketHandle _socket; + std::unique_ptr<CryptoCodec> _codec; + + bool is_blocked(ssize_t res, int error) const { + return ((res < 0) && ((error == EWOULDBLOCK) || (error == EAGAIN))); + } + HandshakeResult hs_try_flush(); + HandshakeResult hs_try_fill(); + ssize_t fill_input(); // -1/0/1 -> error/eof/ok + ssize_t flush_all(); // -1/0 -> error/ok +public: + CryptoCodecAdapter(SocketHandle socket, std::unique_ptr<CryptoCodec> codec) + : _input(64 * 1024), _output(64 * 1024), _socket(std::move(socket)), _codec(std::move(codec)) {} + int get_fd() const override { return _socket.get(); } + HandshakeResult handshake() override; + size_t min_read_buffer_size() const override { return _codec->min_decode_buffer_size(); } + ssize_t read(char *buf, size_t len) override; + ssize_t drain(char *, size_t) override; + ssize_t write(const char *buf, size_t len) override; + ssize_t flush() override; +}; + +} // namespace vespalib::net::tls diff --git a/vespalib/src/vespa/vespalib/net/tls/crypto_exception.cpp b/vespalib/src/vespa/vespalib/net/tls/crypto_exception.cpp new file mode 100644 index 00000000000..41bb2060c04 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/crypto_exception.cpp @@ -0,0 +1,10 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "crypto_exception.h" + +namespace vespalib::net::tls { + +VESPA_IMPLEMENT_EXCEPTION(CryptoException, Exception); + +} + diff --git a/vespalib/src/vespa/vespalib/net/tls/crypto_exception.h b/vespalib/src/vespa/vespalib/net/tls/crypto_exception.h new file mode 100644 index 00000000000..696a158e058 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/crypto_exception.h @@ -0,0 +1,10 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <vespa/vespalib/util/exception.h> + +namespace vespalib::net::tls { + +VESPA_DEFINE_EXCEPTION(CryptoException, Exception); + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/impl/CMakeLists.txt b/vespalib/src/vespa/vespalib/net/tls/impl/CMakeLists.txt new file mode 100644 index 00000000000..a5a8e8d3eb9 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/impl/CMakeLists.txt @@ -0,0 +1,10 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_library(vespalib_vespalib_net_tls_impl OBJECT + SOURCES + openssl_tls_context_impl.cpp + openssl_crypto_codec_impl.cpp + DEPENDS +) +find_package(OpenSSL) +target_include_directories(vespalib_vespalib_net_tls_impl PUBLIC ${OPENSSL_INCLUDE_DIR}) + diff --git a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.cpp b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.cpp new file mode 100644 index 00000000000..a563a43baac --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.cpp @@ -0,0 +1,383 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "openssl_crypto_codec_impl.h" +#include "openssl_tls_context_impl.h" +#include <vespa/vespalib/net/tls/crypto_codec.h> +#include <vespa/vespalib/net/tls/crypto_exception.h> +#include <mutex> +#include <vector> +#include <memory> +#include <stdexcept> +#include <openssl/ssl.h> +#include <openssl/crypto.h> +#include <openssl/err.h> +#include <openssl/pem.h> + +#include <vespa/log/log.h> +LOG_SETUP(".vespalib.net.tls.openssl_crypto_codec_impl"); + +#if (OPENSSL_VERSION_NUMBER < 0x10000000L) +// < 1.0 requires explicit thread ID callback support. +# error "Provided OpenSSL version is too darn old, need at least 1.0" +#endif + +/* + * Beware all ye who dare enter, for this is OpenSSL integration territory. + * Dragons are known to roam the skies. Strange whispers are heard at night + * in the mist-covered lands where the forest meets the lake. Rumors of a + * tome that contains best practices and excellent documentation are heard + * at the local inn, but no one seems to know where it exists, or even if + * it ever existed. Be it best that people carry on with their lives and + * pretend to not know of the beasts that lurk beyond where the torch's + * light fades and turns to all-enveloping darkness. + */ + +namespace vespalib::net::tls::impl { + +namespace { + +bool verify_buf(const char *buf, size_t len) { + return ((len < INT32_MAX) && ((len == 0) || (buf != nullptr))); +} + +const char* ssl_error_to_str(int ssl_error) noexcept { + // From https://www.openssl.org/docs/manmaster/man3/SSL_get_error.html + // Our code paths shouldn't trigger most of these, but included for completeness + switch (ssl_error) { + case SSL_ERROR_NONE: + return "SSL_ERROR_NONE"; + case SSL_ERROR_ZERO_RETURN: + return "SSL_ERROR_ZERO_RETURN"; + case SSL_ERROR_WANT_READ: + return "SSL_ERROR_WANT_READ"; + case SSL_ERROR_WANT_WRITE: + return "SSL_ERROR_WANT_WRITE"; + case SSL_ERROR_WANT_CONNECT: + return "SSL_ERROR_WANT_CONNECT"; + case SSL_ERROR_WANT_ACCEPT: + return "SSL_ERROR_WANT_ACCEPT"; + case SSL_ERROR_WANT_X509_LOOKUP: + return "SSL_ERROR_WANT_X509_LOOKUP"; +#if (OPENSSL_VERSION_NUMBER >= 0x10100000L) + case SSL_ERROR_WANT_ASYNC: + return "SSL_ERROR_WANT_ASYNC"; + case SSL_ERROR_WANT_ASYNC_JOB: + return "SSL_ERROR_WANT_ASYNC_JOB"; +#endif +#if (OPENSSL_VERSION_NUMBER >= 0x10101000L) + case SSL_ERROR_WANT_CLIENT_HELLO_CB: + return "SSL_ERROR_WANT_CLIENT_HELLO_CB"; +#endif + case SSL_ERROR_SYSCALL: + return "SSL_ERROR_SYSCALL"; + case SSL_ERROR_SSL: + return "SSL_ERROR_SSL"; + default: + return "Unknown SSL error code"; + } +} + +HandshakeResult handshake_consumed_bytes_and_needs_more_peer_data(size_t consumed) noexcept { + return {consumed, 0, HandshakeResult::State::NeedsMorePeerData}; +} + +HandshakeResult handshake_produced_bytes_and_needs_more_peer_data(size_t produced) noexcept { + return {0, produced, HandshakeResult::State::NeedsMorePeerData}; +} + +HandshakeResult handshake_consumed_bytes_and_is_complete(size_t consumed) noexcept { + return {consumed, 0, HandshakeResult::State::Done}; +} + +HandshakeResult handshaked_bytes(size_t consumed, size_t produced, HandshakeResult::State state) noexcept { + return {consumed, produced, state}; +} + +HandshakeResult handshake_completed() noexcept { + return {0, 0, HandshakeResult::State::Done}; +} + +HandshakeResult handshake_failed() noexcept { + return {0, 0, HandshakeResult::State::Failed}; +} + +EncodeResult encode_failed() noexcept { + return {0, 0, true}; +} + +EncodeResult encoded_bytes(size_t consumed, size_t produced) noexcept { + return {consumed, produced, false}; +} + +DecodeResult decode_failed() noexcept { + return {0, 0, DecodeResult::State::Failed}; +} + +DecodeResult decoded_frames_with_plaintext_bytes(size_t produced_bytes) noexcept { + return {0, produced_bytes, DecodeResult::State::OK}; +} + +DecodeResult decode_needs_more_peer_data() noexcept { + return {0, 0, DecodeResult::State::NeedsMorePeerData}; +} + +DecodeResult decoded_bytes(size_t consumed, size_t produced, DecodeResult::State state) noexcept { + return {consumed, produced, state}; +} + +BioPtr new_tls_frame_memory_bio() { + BioPtr bio(::BIO_new(BIO_s_mem())); + if (!bio) { + throw CryptoException("IO_new(BIO_s_mem()) failed; out of memory?"); + } + BIO_set_write_buf_size(bio.get(), 0); // 0 ==> default max frame size + return bio; +} + +} // anon ns + +OpenSslCryptoCodecImpl::OpenSslCryptoCodecImpl(::SSL_CTX& ctx, Mode mode) + : _ssl(::SSL_new(&ctx)), + _mode(mode) +{ + if (!_ssl) { + throw CryptoException("Failed to create new SSL from SSL_CTX"); + } + /* + * We use two separate memory BIOs rather than a BIO pair for writing and + * reading ciphertext, respectively. This is because it _seems_ quite + * a bit more straight forward to implement a full duplex API with two + * separate BIOs, but there is little available documentation as to the + * 'hows' and 'whys' around this. + * There are claims from core OpenSSL devs[0] that BIO pairs are more efficient, + * so we may reconsider the current approach (or just use the "OpenSSL controls + * the file descriptor" yolo approach for simplicity, assuming they do optimal + * stuff internally). + * + * Our BIOs are used as follows: + * + * Handshakes may use both BIOs opaquely: + * + * handshake() : SSL_do_handshake() --(_output_bio ciphertext)--> BIO_read --> [peer] + * : SSL_do_handshake() <--(_input_bio ciphertext)-- BIO_write <-- [peer] + * + * Once handshaking is complete, the input BIO is only used for decodes and the output + * BIO is only used for encodes. We explicitly disallow TLS renegotiation, both for + * the sake of simplicity and for added security (renegotiation is a bit of a rat's nest). + * + * encode() : SSL_write(plaintext) --(_output_bio ciphertext)--> BIO_read --> [peer] + * decode() : SSL_read(plaintext) <--(_input_bio ciphertext)-- BIO_write <-- [peer] + * + * To avoid blowing the sizes of BIOs out of the water, we do our best to encode and decode + * on a per-TLS frame granularity (16K) maximum. + */ + BioPtr tmp_input_bio = new_tls_frame_memory_bio(); + BioPtr tmp_output_bio = new_tls_frame_memory_bio(); + // Connect BIOs used internally by OpenSSL. This transfers ownership. No return value to check. + // TODO replace with explicit SSL_set0_rbio/SSL_set0_wbio on OpenSSL >= v1.1 + ::SSL_set_bio(_ssl.get(), tmp_input_bio.get(), tmp_output_bio.get()); + _input_bio = tmp_input_bio.release(); + _output_bio = tmp_output_bio.release(); + if (_mode == Mode::Client) { + ::SSL_set_connect_state(_ssl.get()); + } else { + ::SSL_set_accept_state(_ssl.get()); + } +} + +// TODO remove spammy logging once code is stable + +// Produces bytes previously written to _output_bio by SSL_do_handshake or SSL_write +int OpenSslCryptoCodecImpl::drain_outgoing_network_bytes_if_any( + char *to_peer, size_t to_peer_buf_size) noexcept { + int out_pending = BIO_pending(_output_bio); + if (out_pending > 0) { + int copied = ::BIO_read(_output_bio, to_peer, static_cast<int>(to_peer_buf_size)); + // TODO BIO_should_retry here? Semantics are unclear, especially for memory BIOs. + LOG(spam, "BIO_read copied out %d bytes of ciphertext from _output_bio", copied); + if (copied < 0) { + LOG(error, "Memory BIO_read() failed with BIO_pending() > 0"); + } + return copied; + } + return out_pending; +} + +HandshakeResult OpenSslCryptoCodecImpl::handshake(const char* from_peer, size_t from_peer_buf_size, + char* to_peer, size_t to_peer_buf_size) noexcept { + LOG_ASSERT(verify_buf(from_peer, from_peer_buf_size) && verify_buf(to_peer, to_peer_buf_size)); + + if (SSL_is_init_finished(_ssl.get())) { + return handshake_completed(); + } + // Still ciphertext data left? If so, get rid of it before we start a new operation + // that wants to fill the output BIO. + int produced = drain_outgoing_network_bytes_if_any(to_peer, to_peer_buf_size); + if (produced > 0) { + // Handshake isn't complete yet and we've got stuff to send. Need to continue handshake + // once more data is available from the peer. + return handshake_produced_bytes_and_needs_more_peer_data(static_cast<size_t>(produced)); + } else if (produced < 0) { + return handshake_failed(); + } + const auto consume_res = do_handshake_and_consume_peer_input_bytes(from_peer, from_peer_buf_size); + LOG_ASSERT(consume_res.bytes_produced == 0); + if (consume_res.failed()) { + return consume_res; + } + // SSL_do_handshake() might have produced more data to send. Note: handshake may + // be complete at this point. + produced = drain_outgoing_network_bytes_if_any(to_peer, to_peer_buf_size); + if (produced < 0) { + return handshake_failed(); + } + return handshaked_bytes(consume_res.bytes_consumed, static_cast<size_t>(produced), consume_res.state); +} + +HandshakeResult OpenSslCryptoCodecImpl::do_handshake_and_consume_peer_input_bytes( + const char *from_peer, size_t from_peer_buf_size) noexcept { + // Feed the SSL session input in frame-sized chunks between each call to SSL_do_handshake(). + // This is primarily to ensure we don't shove unbounded amounts of data into the BIO + // in the case that someone naughty is sending us tons of garbage over the socket. + size_t consumed_total = 0; + while (true) { + // Assumption: SSL_do_handshake will place all required outgoing handshake + // data in the output memory BIO without requiring WANT_WRITE. Freestanding + // memory BIOs are _supposedly_ auto-resizing, so this should work transparently. + // At the very least, if this is not the case we'll auto-fail the connection + // and quickly find out..! + // TODO test multi-frame sized handshake + // TODO should we invoke ::ERR_clear_error() prior? + int ssl_result = ::SSL_do_handshake(_ssl.get()); + ssl_result = ::SSL_get_error(_ssl.get(), ssl_result); + + if (ssl_result == SSL_ERROR_WANT_READ) { + LOG(spam, "SSL_do_handshake() returned SSL_ERROR_WANT_READ"); + if (from_peer_buf_size - consumed_total > 0) { + int consumed = ::BIO_write(_input_bio, from_peer + consumed_total, + static_cast<int>(std::min(MaximumTlsFrameSize, from_peer_buf_size - consumed_total))); + LOG(spam, "BIO_write copied in %d bytes of ciphertext to _input_bio", consumed); + if (consumed < 0) { + LOG(error, "Memory BIO_write() returned %d", consumed); // TODO BIO_need_retry? + return handshake_failed(); + } + consumed_total += consumed; // TODO protect against consumed == 0? + continue; + } else { + return handshake_consumed_bytes_and_needs_more_peer_data(consumed_total); + } + } else if (ssl_result == SSL_ERROR_NONE) { + // At this point SSL_do_handshake has stated it does not need any more peer data, i.e. + // the handshake is complete. + if (!SSL_is_init_finished(_ssl.get())) { + LOG(error, "SSL handshake is not completed even though no more peer data is requested"); + return handshake_failed(); + } + return handshake_consumed_bytes_and_is_complete(consumed_total); + } else { + LOG(error, "SSL_do_handshake() returned unexpected error: %s", ssl_error_to_str(ssl_result)); + return handshake_failed(); + } + }; +} + +EncodeResult OpenSslCryptoCodecImpl::encode(const char* plaintext, size_t plaintext_size, + char* ciphertext, size_t ciphertext_size) noexcept { + LOG_ASSERT(verify_buf(plaintext, plaintext_size) && verify_buf(ciphertext, ciphertext_size)); + + if (!SSL_is_init_finished(_ssl.get())) { + LOG(error, "OpenSslCryptoCodecImpl::encode() called before handshake completed"); + return encode_failed(); + } + size_t bytes_consumed = 0; + if (plaintext_size != 0) { + int to_consume = static_cast<int>(std::min(plaintext_size, MaximumFramePlaintextSize)); + // SSL_write encodes plaintext to ciphertext and writes to _output_bio + int consumed = ::SSL_write(_ssl.get(), plaintext, to_consume); + LOG(spam, "After SSL_write() -> %d, _input_bio pending=%d, _output_bio pending=%d", + consumed, BIO_pending(_input_bio), BIO_pending(_output_bio)); + if (consumed < 0) { + int ssl_error = ::SSL_get_error(_ssl.get(), consumed); + LOG(error, "SSL_write() failed to write frame, got error %s", ssl_error_to_str(ssl_error)); + // TODO explicitly detect and log TLS renegotiation error (SSL_ERROR_WANT_READ)? + return encode_failed(); + } else if (consumed != to_consume) { + LOG(error, "SSL_write() returned OK but did not consume all requested plaintext"); + return encode_failed(); + } + bytes_consumed = static_cast<size_t>(consumed); + } + + int produced = drain_outgoing_network_bytes_if_any(ciphertext, ciphertext_size); + if (produced < 0) { + return encode_failed(); + } + if (BIO_pending(_output_bio) != 0) { + LOG(error, "Residual data left in output BIO on encode(); provided buffer is too small"); + return encode_failed(); + } + return encoded_bytes(bytes_consumed, static_cast<size_t>(produced)); +} +DecodeResult OpenSslCryptoCodecImpl::decode(const char* ciphertext, size_t ciphertext_size, + char* plaintext, size_t plaintext_size) noexcept { + LOG_ASSERT(verify_buf(ciphertext, ciphertext_size) && verify_buf(plaintext, plaintext_size)); + + if (!SSL_is_init_finished(_ssl.get())) { + LOG(error, "OpenSslCryptoCodecImpl::decode() called before handshake completed"); + return decode_failed(); + } + auto produce_res = drain_and_produce_plaintext_from_ssl(plaintext, static_cast<int>(plaintext_size)); + if ((produce_res.bytes_produced > 0) || produce_res.failed()) { + return produce_res; // TODO gRPC [1] handles this differently... allows fallthrough + } + int consumed = consume_peer_input_bytes(ciphertext, ciphertext_size); + if (consumed < 0) { + return decode_failed(); + } + produce_res = drain_and_produce_plaintext_from_ssl(plaintext, static_cast<int>(plaintext_size)); + return decoded_bytes(static_cast<size_t>(consumed), produce_res.bytes_produced, produce_res.state); +} + +DecodeResult OpenSslCryptoCodecImpl::drain_and_produce_plaintext_from_ssl( + char* plaintext, size_t plaintext_size) noexcept { + // SSL_read() is named a bit confusingly. We read _from_ the SSL-internal state + // via the input BIO _into_ to the receiving plaintext buffer. + // This may consume the entire, parts of, or none of the input BIO's data, + // depending on how much TLS frame data is available and its size relative + // to the receiving plaintext buffer. + int produced = ::SSL_read(_ssl.get(), plaintext, static_cast<int>(plaintext_size)); + LOG(spam, "After SSL_read() -> %d, _input_bio pending=%d, _output_bio pending=%d", + produced, BIO_pending(_input_bio), BIO_pending(_output_bio)); + if (produced > 0) { + // At least 1 frame decoded successfully. + return decoded_frames_with_plaintext_bytes(static_cast<size_t>(produced)); + } else { + int ssl_error = ::SSL_get_error(_ssl.get(), produced); + switch (ssl_error) { + case SSL_ERROR_WANT_READ: + // SSL_read() was not able to decode a full frame with the ciphertext that + // we've fed it thus far; caller must feed it some and then try again. + LOG(spam, "SSL_read() returned SSL_ERROR_WANT_READ, must get more ciphertext"); + return decode_needs_more_peer_data(); + default: + LOG(error, "SSL_read() returned unexpected error: %s", ssl_error_to_str(ssl_error)); + return decode_failed(); + } + } +} + +int OpenSslCryptoCodecImpl::consume_peer_input_bytes( + const char* ciphertext, size_t ciphertext_size) noexcept { + // TODO BIO_need_retry on failure? Can this even happen for memory BIOs? + int consumed = ::BIO_write(_input_bio, ciphertext, static_cast<int>(std::min(MaximumTlsFrameSize, ciphertext_size))); + LOG(spam, "BIO_write copied in %d bytes of ciphertext to _input_bio", consumed); + if (consumed < 0) { + LOG(error, "Memory BIO_write() returned %d", consumed); + } + return consumed; +} + +} + +// External references: +// [0] http://openssl.6102.n7.nabble.com/nonblocking-implementation-question-tp1728p1732.html +// [1] https://github.com/grpc/grpc/blob/master/src/core/tsi/ssl_transport_security.cc diff --git a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.h b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.h new file mode 100644 index 00000000000..44ca8859596 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.h @@ -0,0 +1,76 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "openssl_typedefs.h" +#include <vespa/vespalib/net/tls/transport_security_options.h> +#include <vespa/vespalib/net/tls/crypto_codec.h> +#include <memory> + +namespace vespalib::net::tls { class TlsContext; } + +namespace vespalib::net::tls::impl { + +/* + * Frame-level OpenSSL-backed TLSv1.2 crypto codec implementation. + * + * Currently has sub-optimal buffer management, and is mostly intended + * as a starting point. + * + * NOT thread safe per instance, but independent instances may be + * used by different threads safely. + */ +class OpenSslCryptoCodecImpl : public CryptoCodec { + SslPtr _ssl; + ::BIO* _input_bio; // Owned by _ssl + ::BIO* _output_bio; // Owned by _ssl + Mode _mode; +public: + OpenSslCryptoCodecImpl(::SSL_CTX& ctx, Mode mode); + + /* + * From RFC 8449 (Record Size Limit Extension for TLS), section 1: + * "TLS versions 1.2 [RFC5246] and earlier permit senders to + * generate records 16384 octets in size, plus any expansion + * from compression and protection up to 2048 octets (though + * typically this expansion is only 16 octets). TLS 1.3 reduces + * the allowance for expansion to 256 octets." + * + * We're on TLSv1.2, so make room for the worst case. + */ + static constexpr size_t MaximumTlsFrameSize = 16384 + 2048; + static constexpr size_t MaximumFramePlaintextSize = 16384; + + size_t min_encode_buffer_size() const noexcept override { + return MaximumTlsFrameSize; + } + size_t min_decode_buffer_size() const noexcept override { + return MaximumFramePlaintextSize; + } + + HandshakeResult handshake(const char* from_peer, size_t from_peer_buf_size, + char* to_peer, size_t to_peer_buf_size) noexcept override; + + EncodeResult encode(const char* plaintext, size_t plaintext_size, + char* ciphertext, size_t ciphertext_size) noexcept override; + DecodeResult decode(const char* ciphertext, size_t ciphertext_size, + char* plaintext, size_t plaintext_size) noexcept override; +private: + /* + * Returns + * n > 0 if n bytes written to `to_peer`. Always <= to_peer_buf_size + * n == 0 if no bytes pending in output BIO + * n < 0 on error + */ + int drain_outgoing_network_bytes_if_any(char *to_peer, size_t to_peer_buf_size) noexcept; + /* + * Returns + * n > 0 if n bytes written to `ciphertext`. Always <= ciphertext_size + * n == 0 if no bytes pending in input BIO + * n < 0 on error + */ + int consume_peer_input_bytes(const char* ciphertext, size_t ciphertext_size) noexcept; + HandshakeResult do_handshake_and_consume_peer_input_bytes(const char *from_peer, size_t from_peer_buf_size) noexcept; + DecodeResult drain_and_produce_plaintext_from_ssl(char* plaintext, size_t plaintext_size) noexcept; +}; + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.cpp b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.cpp new file mode 100644 index 00000000000..27250dd43fc --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.cpp @@ -0,0 +1,269 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "openssl_typedefs.h" +#include "openssl_tls_context_impl.h" +#include <vespa/vespalib/net/tls/crypto_exception.h> +#include <vespa/vespalib/net/tls/transport_security_options.h> +#include <mutex> +#include <vector> +#include <memory> +#include <stdexcept> +#include <openssl/ssl.h> +#include <openssl/crypto.h> +#include <openssl/err.h> +#include <openssl/pem.h> + +#include <vespa/log/log.h> +LOG_SETUP(".vespalib.net.tls.openssl_tls_context_impl"); + +#if (OPENSSL_VERSION_NUMBER < 0x10000000L) +// < 1.0 requires explicit thread ID callback support. +# error "Provided OpenSSL version is too darn old, need at least 1.0" +#endif + +namespace vespalib::net::tls::impl { + +namespace { + +#if (OPENSSL_VERSION_NUMBER < 0x10100000L) + +std::vector<std::unique_ptr<std::mutex>> _g_mutexes; + +// Some works on OpenSSL legacy locking: OpenSSL does not implement locking +// itself internally, deferring to user code callbacks that Do The Needful(tm). +// The `n` parameter refers to the nth mutex, which is always < CRYPTO_num_locks(). +void openssl_locking_cb(int mode, int n, [[maybe_unused]] const char *file, [[maybe_unused]] int line) { + if (mode & CRYPTO_LOCK) { + _g_mutexes[n]->lock(); + } else { + _g_mutexes[n]->unlock(); + } +} + +#endif + +struct OpenSslLibraryResources { + OpenSslLibraryResources(); + ~OpenSslLibraryResources(); +}; + +OpenSslLibraryResources::OpenSslLibraryResources() { + // Other implementations (Asio, gRPC) disagree on whether main library init + // itself should take place on >= v1.1. We always do it to be on the safe side..! + ::SSL_library_init(); + ::SSL_load_error_strings(); + ::OpenSSL_add_all_algorithms(); + // Luckily, the mutex callback madness is not present on >= v1.1 +#if (OPENSSL_VERSION_NUMBER < 0x10100000L) + // Since the init path should happen only once globally, but multiple libraries + // may use OpenSSL, make sure we don't step on any toes if locking callbacks are + // already set up. + if (!::CRYPTO_get_locking_callback()) { + const int num_locks = ::CRYPTO_num_locks(); + LOG_ASSERT(num_locks > 0); + _g_mutexes.reserve(num_locks); + for (int i = 0; i < num_locks; ++i) { + _g_mutexes.emplace_back(std::make_unique<std::mutex>()); + } + ::CRYPTO_set_locking_callback(openssl_locking_cb); + } +#endif +} + +OpenSslLibraryResources::~OpenSslLibraryResources() { +#if (OPENSSL_VERSION_NUMBER < 0x10100000L) + if (::CRYPTO_get_locking_callback() == openssl_locking_cb) { + ::CRYPTO_set_locking_callback(nullptr); + } +#endif + ERR_free_strings(); + EVP_cleanup(); + CRYPTO_cleanup_all_ex_data(); +} + +// TODO make global init instead..? +void ensure_openssl_initialized_once() { + static OpenSslLibraryResources openssl_resources; + (void) openssl_resources; +} + +BioPtr bio_from_string(vespalib::stringref str) { + LOG_ASSERT(str.size() <= INT_MAX); +#if (OPENSSL_VERSION_NUMBER >= 0x10002000L) + BioPtr bio(::BIO_new_mem_buf(str.data(), static_cast<int>(str.size()))); +#else + BioPtr bio(::BIO_new_mem_buf(const_cast<char*>(str.data()), static_cast<int>(str.size()))); +#endif + if (!bio) { + throw CryptoException("BIO_new_mem_buf"); + } + return bio; +} + +// Several OpenSSL functions take a magical user passphrase argument with +// potentially horrible default behavior for password protected input. +// +// From OpenSSL docs (https://www.openssl.org/docs/man1.1.0/crypto/PEM_read_bio_PrivateKey.html): +// +// "If the cb parameters is set to NULL and the u parameter is not NULL +// then the u parameter is interpreted as a null terminated string to use +// as the passphrase. If both cb and u are NULL then the default callback +// routine is used which will typically prompt for the passphrase on the +// current terminal with echoing turned off." +// +// Neat! +// +// Bonus points for being non-const as well. +constexpr inline void *empty_passphrase() { + return const_cast<void *>(static_cast<const void *>("")); +} + +// Attempt to read a PEM encoded (trusted) certificate from the given BIO. +// BIO might contain further certificates if function returns non-nullptr. +// Returns nullptr if no certificate could be loaded. This is usually an error, +// as this should be the first certificate in the chain. +X509Ptr read_trusted_x509_from_bio(::BIO& bio) { + // "_AUX" means the certificate is trusted. Why they couldn't name this function + // something with "trusted" instead is left as an exercise to the reader. + return X509Ptr(::PEM_read_bio_X509_AUX(&bio, nullptr, nullptr, empty_passphrase())); +} + +// Attempt to read a PEM encoded certificate from the given BIO. +// BIO might contain further certificates if function returns non-nullptr. +// Returns nullptr if no certificate could be loaded. This usually implies +// that there are no more certificates left in the chain. +X509Ptr read_untrusted_x509_from_bio(::BIO& bio) { + return X509Ptr(::PEM_read_bio_X509(&bio, nullptr, nullptr, empty_passphrase())); +} + +SslCtxPtr new_tls_ctx_with_auto_init() { + ensure_openssl_initialized_once(); +#if (OPENSSL_VERSION_NUMBER < 0x10100000L) + return SslCtxPtr(::SSL_CTX_new(::TLSv1_2_method())); +#else + SslCtxPtr ctx(::SSL_CTX_new(::TLS_method())); + if (!::SSL_CTX_set_min_proto_version(ctx.get(), TLS1_2_VERSION)) { + throw CryptoException("SSL_CTX_set_min_proto_version"); + } + return ctx; +#endif +} + +} // anon ns + +OpenSslTlsContextImpl::OpenSslTlsContextImpl(const TransportSecurityOptions& ts_opts) + : _ctx(new_tls_ctx_with_auto_init()) +{ + if (!_ctx) { + throw CryptoException("Failed to create new TLS context"); + } + add_certificate_authorities(ts_opts.ca_certs_pem()); + add_certificate_chain(ts_opts.cert_chain_pem()); + use_private_key(ts_opts.private_key_pem()); + verify_private_key(); + enable_ephemeral_key_exchange(); + disable_compression(); + enforce_peer_certificate_verification(); + // TODO set accepted cipher suites! + // TODO `--> If not set in options, use Modern spec from https://wiki.mozilla.org/Security/Server_Side_TLS +} + +OpenSslTlsContextImpl::~OpenSslTlsContextImpl() = default; + +void OpenSslTlsContextImpl::add_certificate_authorities(vespalib::stringref ca_pem) { + // TODO support empty CA set...? Ever useful? + auto bio = bio_from_string(ca_pem); + ::X509_STORE* cert_store = ::SSL_CTX_get_cert_store(_ctx.get()); // Internal pointer, not owned by us. + while (true) { + auto ca_cert = read_untrusted_x509_from_bio(*bio); + if (!ca_cert) { + break; + } + if (::X509_STORE_add_cert(cert_store, ca_cert.get()) != 1) { // Does _not_ take ownership + throw CryptoException("X509_STORE_add_cert"); + } + } +} + +void OpenSslTlsContextImpl::add_certificate_chain(vespalib::stringref chain_pem) { + ::ERR_clear_error(); + auto bio = bio_from_string(chain_pem); + // First certificate in the chain is the node's own (trusted) certificate. + auto own_cert = read_trusted_x509_from_bio(*bio); + if (!own_cert) { + throw CryptoException("No X509 certificates could be found in provided chain"); + } + // Ownership of certificate is _not_ transferred, OpenSSL makes internal copy. + // This is not well documented, but is mentioned by other impls. + if (::SSL_CTX_use_certificate(_ctx.get(), own_cert.get()) != 1) { + throw CryptoException("SSL_CTX_use_certificate"); + } + // After the node's own certificate comes any intermediate CA-provided certificates. + while (true) { + auto ca_cert = read_untrusted_x509_from_bio(*bio); + if (!ca_cert) { + // No more certificates in chain, hooray! + ::ERR_clear_error(); + break; + } + // Ownership of certificate _is_ transferred here! + if (!::SSL_CTX_add_extra_chain_cert(_ctx.get(), ca_cert.release())) { + throw CryptoException("SSL_CTX_add_extra_chain_cert"); + } + } +} + +void OpenSslTlsContextImpl::use_private_key(vespalib::stringref key_pem) { + auto bio = bio_from_string(key_pem); + EvpPkeyPtr key(::PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, empty_passphrase())); + if (!key) { + throw CryptoException("Failed to read PEM private key data"); + } + // Ownership _not_ taken. + if (::SSL_CTX_use_PrivateKey(_ctx.get(), key.get()) != 1) { + throw CryptoException("SSL_CTX_use_PrivateKey"); + } +} + +void OpenSslTlsContextImpl::verify_private_key() { + if (::SSL_CTX_check_private_key(_ctx.get()) != 1) { + throw CryptoException("SSL_CTX_check_private_key failed; mismatch between public and private key?"); + } +} + +void OpenSslTlsContextImpl::enable_ephemeral_key_exchange() { +#if (OPENSSL_VERSION_NUMBER < 0x10100000L) +# if (OPENSSL_VERSION_NUMBER >= 0x10002000L) + // Always enabled by default on higher versions. + // Auto curve selection is preferred over using SSL_CTX_set_ecdh_tmp + if (!::SSL_CTX_set_ecdh_auto(_ctx.get(), 1)) { + throw CryptoException("SSL_CTX_set_ecdh_auto"); + } + // New ECDH key per connection. + ::SSL_CTX_set_options(_ctx.get(), SSL_OP_SINGLE_ECDH_USE); +# else + // Set explicit P-256 curve used for ECDH purposes. + EcKeyPtr ec_curve(::EC_KEY_new_by_curve_name(NID_X9_62_prime256v1)); + if (!ec_curve) { + throw CryptoException("EC_KEY_new_by_curve_name(NID_X9_62_prime256v1)"); + } + if (!::SSL_CTX_set_tmp_ecdh(_ctx.get(), ec_curve.get())) { + throw CryptoException("SSL_CTX_set_tmp_ecdh"); + } +# endif +#endif +} + +void OpenSslTlsContextImpl::disable_compression() { + // TLS stream compression is vulnerable to a host of chosen plaintext + // attacks (CRIME, BREACH etc), so disable it. + ::SSL_CTX_set_options(_ctx.get(), SSL_OP_NO_COMPRESSION); +} + +void OpenSslTlsContextImpl::enforce_peer_certificate_verification() { + // We require full mutual certificate verification. No way to configure + // out of this, at least not for the time being. + // TODO verification callback for custom CN/SAN etc checks. + SSL_CTX_set_verify(_ctx.get(), SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); +} + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.h b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.h new file mode 100644 index 00000000000..72f9f3b570d --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.h @@ -0,0 +1,29 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "openssl_typedefs.h" +#include <vespa/vespalib/net/tls/tls_context.h> +#include <vespa/vespalib/stllike/string.h> + +namespace vespalib::net::tls::impl { + +class OpenSslTlsContextImpl : public TlsContext { + SslCtxPtr _ctx; +public: + explicit OpenSslTlsContextImpl(const TransportSecurityOptions&); + ~OpenSslTlsContextImpl() override; + + ::SSL_CTX* native_context() const noexcept { return _ctx.get(); } +private: + // Note: single use per instance; does _not_ clear existing chain! + void add_certificate_authorities(stringref ca_pem); + void add_certificate_chain(stringref chain_pem); + void use_private_key(stringref key_pem); + void verify_private_key(); + // Enable use of ephemeral key exchange (ECDHE), allowing forward secrecy. + void enable_ephemeral_key_exchange(); + void disable_compression(); + void enforce_peer_certificate_verification(); +}; + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_typedefs.h b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_typedefs.h new file mode 100644 index 00000000000..afafe556338 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_typedefs.h @@ -0,0 +1,53 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <memory> +#include <openssl/ssl.h> +#include <openssl/crypto.h> +#include <openssl/x509.h> + +namespace vespalib::net::tls::impl { + +struct BioDeleter { + void operator()(::BIO* bio) const noexcept { + ::BIO_free(bio); + } +}; +using BioPtr = std::unique_ptr<::BIO, BioDeleter>; + +struct SslDeleter { + void operator()(::SSL* ssl) const noexcept { + ::SSL_free(ssl); + } +}; +using SslPtr = std::unique_ptr<::SSL, SslDeleter>; + +struct SslCtxDeleter { + void operator()(::SSL_CTX* ssl) const noexcept { + ::SSL_CTX_free(ssl); + } +}; +using SslCtxPtr = std::unique_ptr<::SSL_CTX, SslCtxDeleter>; + +struct X509Deleter { + void operator()(::X509* cert) const noexcept { + ::X509_free(cert); + } +}; +using X509Ptr = std::unique_ptr<::X509, X509Deleter>; + +struct EvpPkeyDeleter { + void operator()(::EVP_PKEY* pkey) const noexcept { + ::EVP_PKEY_free(pkey); + } +}; +using EvpPkeyPtr = std::unique_ptr<::EVP_PKEY, EvpPkeyDeleter>; + +struct EcKeyDeleter { + void operator()(::EC_KEY* ec_key) const noexcept { + ::EC_KEY_free(ec_key); + } +}; +using EcKeyPtr = std::unique_ptr<::EC_KEY, EcKeyDeleter>; + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/tls_context.cpp b/vespalib/src/vespa/vespalib/net/tls/tls_context.cpp new file mode 100644 index 00000000000..467838975e7 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/tls_context.cpp @@ -0,0 +1,11 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "tls_context.h" +#include <vespa/vespalib/net/tls/impl/openssl_tls_context_impl.h> + +namespace vespalib::net::tls { + +std::unique_ptr<TlsContext> TlsContext::create_default_context(const TransportSecurityOptions& opts) { + return std::make_unique<impl::OpenSslTlsContextImpl>(opts); +} + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/tls_context.h b/vespalib/src/vespa/vespalib/net/tls/tls_context.h new file mode 100644 index 00000000000..7292f43f88c --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/tls_context.h @@ -0,0 +1,16 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <memory> + +namespace vespalib::net::tls { + +class TransportSecurityOptions; + +struct TlsContext { + virtual ~TlsContext() = default; + + static std::unique_ptr<TlsContext> create_default_context(const TransportSecurityOptions&); +}; + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/tls_crypto_engine.cpp b/vespalib/src/vespa/vespalib/net/tls/tls_crypto_engine.cpp new file mode 100644 index 00000000000..72d9eacf37c --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/tls_crypto_engine.cpp @@ -0,0 +1,22 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "tls_crypto_engine.h" +#include "crypto_codec.h" +#include "crypto_codec_adapter.h" + +namespace vespalib { + +TlsCryptoEngine::TlsCryptoEngine(net::tls::TransportSecurityOptions tls_opts) + : _tls_ctx(net::tls::TlsContext::create_default_context(tls_opts)) +{ +} + +CryptoSocket::UP +TlsCryptoEngine::create_crypto_socket(SocketHandle socket, bool is_server) +{ + auto mode = is_server ? net::tls::CryptoCodec::Mode::Server : net::tls::CryptoCodec::Mode::Client; + auto codec = net::tls::CryptoCodec::create_default_codec(*_tls_ctx, mode); + return std::make_unique<net::tls::CryptoCodecAdapter>(std::move(socket), std::move(codec)); +} + +} // namespace vespalib diff --git a/vespalib/src/vespa/vespalib/net/tls/tls_crypto_engine.h b/vespalib/src/vespa/vespalib/net/tls/tls_crypto_engine.h new file mode 100644 index 00000000000..58fda2b3b21 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/tls_crypto_engine.h @@ -0,0 +1,23 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/vespalib/net/crypto_engine.h> +#include "transport_security_options.h" +#include "tls_context.h" + +namespace vespalib { + +/** + * Crypto engine implementing TLS. + **/ +class TlsCryptoEngine : public CryptoEngine +{ +private: + std::unique_ptr<net::tls::TlsContext> _tls_ctx; +public: + TlsCryptoEngine(net::tls::TransportSecurityOptions tls_opts); + CryptoSocket::UP create_crypto_socket(SocketHandle socket, bool is_server) override; +}; + +} // namespace vespalib diff --git a/vespalib/src/vespa/vespalib/net/tls/transport_security_options.cpp b/vespalib/src/vespa/vespalib/net/tls/transport_security_options.cpp new file mode 100644 index 00000000000..4e39fe4d7fa --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/transport_security_options.cpp @@ -0,0 +1,12 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "transport_security_options.h" +#include <openssl/crypto.h> + +namespace vespalib::net::tls { + +TransportSecurityOptions::~TransportSecurityOptions() { + OPENSSL_cleanse(&_private_key_pem[0], _private_key_pem.size()); +} + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/transport_security_options.h b/vespalib/src/vespa/vespalib/net/tls/transport_security_options.h new file mode 100644 index 00000000000..0a228388791 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/transport_security_options.h @@ -0,0 +1,30 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/vespalib/stllike/string.h> + +namespace vespalib::net::tls { + +class TransportSecurityOptions { + vespalib::string _ca_certs_pem; + vespalib::string _cert_chain_pem; + vespalib::string _private_key_pem; +public: + TransportSecurityOptions() = default; + + TransportSecurityOptions(vespalib::string ca_certs_pem, + vespalib::string cert_chain_pem, + vespalib::string private_key_pem) + : _ca_certs_pem(std::move(ca_certs_pem)), + _cert_chain_pem(std::move(cert_chain_pem)), + _private_key_pem(std::move(private_key_pem)) + {} + ~TransportSecurityOptions(); + + const vespalib::string& ca_certs_pem() const noexcept { return _ca_certs_pem; } + const vespalib::string& cert_chain_pem() const noexcept { return _cert_chain_pem; } + const vespalib::string& private_key_pem() const noexcept { return _private_key_pem; } +}; + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/transport_security_options_reading.cpp b/vespalib/src/vespa/vespalib/net/tls/transport_security_options_reading.cpp new file mode 100644 index 00000000000..05cfc797e51 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/transport_security_options_reading.cpp @@ -0,0 +1,102 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "transport_security_options_reading.h" +#include <vespa/vespalib/data/slime/slime.h> +#include <vespa/vespalib/util/exceptions.h> +#include <vespa/vespalib/io/fileutil.h> +#include <vespa/vespalib/io/mapped_file_input.h> +#include <vespa/vespalib/data/memory_input.h> + +namespace vespalib::net::tls { + +/* + + Proposed JSON format for TLS configuration file: + +{ + "files": { + "private-key": "myhost.key", + "ca-certificates": "my_cas.pem", + "certificates": "certs.pem" + }, + // for later: + "peer-taggers": [ + { + "requirements":[ + { + "field": "SAN" + "must-match": "DNS:foo.bar.baz.*" + } + ], + "tags": ["cluster-peers", "config-server"] // or "roles"? Avoid ambiguities with Athenz concepts + }, + { + "requirements":[ + { "field":"CN", "must-match": "config.blarg.*"} + ], + "tags": ["config-server"] + } + ] +} + + */ + +using namespace slime::convenience; + +namespace { + +constexpr const char* files_field = "files"; +constexpr const char* private_key_field = "private-key"; +constexpr const char* ca_certs_field = "ca-certificates"; +constexpr const char* certs_field = "certificates"; + +void verify_referenced_file_exists(const vespalib::string& file_path) { + if (!fileExists(file_path)) { + throw IllegalArgumentException(make_string("File '%s' referenced by TLS config does not exist", file_path.c_str())); + } +} + +vespalib::string load_file_referenced_by_field(const Cursor& cursor, const char* field) { + auto file_path = cursor[field].asString().make_string(); + if (file_path.empty()) { + throw IllegalArgumentException(make_string("TLS config field '%s' has not been set", field)); + } + verify_referenced_file_exists(file_path); + return File::readAll(file_path); +} + +std::unique_ptr<TransportSecurityOptions> load_from_input(Input& input) { + Slime root; + auto parsed = slime::JsonFormat::decode(input, root); + if (parsed == 0) { + throw IllegalArgumentException("Provided TLS config file is not valid JSON"); + } + auto& files = root[files_field]; + if (files.fields() == 0) { + throw IllegalArgumentException("TLS config root field 'files' is missing or empty"); + } + // Note: we do no look at the _contents_ of the files; this is deferred to the + // TLS context code which actually tries to extract key and certificate material + // from them. + auto ca_certs = load_file_referenced_by_field(files, ca_certs_field); + auto certs = load_file_referenced_by_field(files, certs_field); + auto priv_key = load_file_referenced_by_field(files, private_key_field); + + return std::make_unique<TransportSecurityOptions>(std::move(ca_certs), std::move(certs), std::move(priv_key)); +} + +} // anon ns + +std::unique_ptr<TransportSecurityOptions> read_options_from_json_string(const vespalib::string& json_data) { + MemoryInput file_input(json_data); + return load_from_input(file_input); +} + +std::unique_ptr<TransportSecurityOptions> read_options_from_json_file(const vespalib::string& file_path) { + MappedFileInput file_input(file_path); + if (!file_input.valid()) { + throw IllegalArgumentException(make_string("TLS config file '%s' could not be read", file_path.c_str())); + } + return load_from_input(file_input); +} + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/transport_security_options_reading.h b/vespalib/src/vespa/vespalib/net/tls/transport_security_options_reading.h new file mode 100644 index 00000000000..800b3b5ed0d --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/transport_security_options_reading.h @@ -0,0 +1,20 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "transport_security_options.h" +#include <memory> + +namespace vespalib::net::tls { + +// TODO consider renaming TransportSecurityOptions -> TlsConfig + +/** + * Throws IoException if file_path or any files referenced by it can't be accessed + * Throws IllegalArgumentException if file is not parseable as a valid TLS config file or + * if mandatory JSON fields are missing or incomplete. + */ +std::unique_ptr<TransportSecurityOptions> read_options_from_json_file(const vespalib::string& file_path); +// Same properties as read_options_from_json_file() +std::unique_ptr<TransportSecurityOptions> read_options_from_json_string(const vespalib::string& json_data); + +} diff --git a/vespalib/src/vespa/vespalib/test/CMakeLists.txt b/vespalib/src/vespa/vespalib/test/CMakeLists.txt index 4c2c65e8793..4eb47735ca7 100644 --- a/vespalib/src/vespa/vespalib/test/CMakeLists.txt +++ b/vespalib/src/vespa/vespalib/test/CMakeLists.txt @@ -1,5 +1,6 @@ # Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -vespa_add_library(vespalib_vespalib_test INTERFACE +vespa_add_library(vespalib_vespalib_test OBJECT SOURCES + make_tls_options_for_testing.cpp DEPENDS ) diff --git a/vespalib/src/vespa/vespalib/test/make_tls_options_for_testing.cpp b/vespalib/src/vespa/vespalib/test/make_tls_options_for_testing.cpp new file mode 100644 index 00000000000..e70914dec2f --- /dev/null +++ b/vespalib/src/vespa/vespalib/test/make_tls_options_for_testing.cpp @@ -0,0 +1,77 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "make_tls_options_for_testing.h" + +/* + * Generated with the following commands: + * + * openssl ecparam -name prime256v1 -genkey -out ca.key + * + * openssl req -new -x509 -nodes -key ca.key \ + * -sha256 -out ca.pem \ + * -subj '/C=US/L=LooneyVille/O=ACME/OU=ACME test CA/CN=acme.example.com' \ + * -days 10000 + * + * openssl ecparam -name prime256v1 -genkey -out host.key + * + * openssl req -new -key host.key -out host.csr \ + * -subj '/C=US/L=LooneyVille/O=Wile. E. Coyote, Ltd./CN=wile.example.com' \ + * -sha256 + * + * openssl x509 -req -in host.csr \ + * -CA ca.pem \ + * -CAkey ca.key \ + * -CAcreateserial \ + * -out host.pem \ + * -days 10000 \ + * -sha256 + * + * TODO generate keypairs and certs at test-time to avoid any hard-coding + * There certs are valid until 2046, so that buys us some time..! + */ + +// ca.pem +constexpr const char* ca_pem = R"(-----BEGIN CERTIFICATE----- +MIIBuDCCAV4CCQDpVjQIixTxvDAKBggqhkjOPQQDAjBkMQswCQYDVQQGEwJVUzEU +MBIGA1UEBwwLTG9vbmV5VmlsbGUxDTALBgNVBAoMBEFDTUUxFTATBgNVBAsMDEFD +TUUgdGVzdCBDQTEZMBcGA1UEAwwQYWNtZS5leGFtcGxlLmNvbTAeFw0xODA4MzEx +MDU3NDVaFw00NjAxMTYxMDU3NDVaMGQxCzAJBgNVBAYTAlVTMRQwEgYDVQQHDAtM +b29uZXlWaWxsZTENMAsGA1UECgwEQUNNRTEVMBMGA1UECwwMQUNNRSB0ZXN0IENB +MRkwFwYDVQQDDBBhY21lLmV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0D +AQcDQgAE1L7IzCN5pbyVnBATIHieuxq+hf9kWyn5yfjkXMhD52T5ITz1huq4nbiN +YtRoRP7XmipI60R/uiCHzERcsVz4rDAKBggqhkjOPQQDAgNIADBFAiEA6wmZDBca +y0aJ6ABtjbjx/vlmVDxdkaSZSgO8h2CkvIECIFktCkbZhDFfSvbqUScPOGuwkdGQ +L/EW2Bxp+1BPcYoZ +-----END CERTIFICATE-----)"; + +// host.pem +constexpr const char* cert_pem = R"(-----BEGIN CERTIFICATE----- +MIIBsTCCAVgCCQD6GfDh0ltpsjAKBggqhkjOPQQDAjBkMQswCQYDVQQGEwJVUzEU +MBIGA1UEBwwLTG9vbmV5VmlsbGUxDTALBgNVBAoMBEFDTUUxFTATBgNVBAsMDEFD +TUUgdGVzdCBDQTEZMBcGA1UEAwwQYWNtZS5leGFtcGxlLmNvbTAeFw0xODA4MzEx +MDU3NDVaFw00NjAxMTYxMDU3NDVaMF4xCzAJBgNVBAYTAlVTMRQwEgYDVQQHDAtM +b29uZXlWaWxsZTEeMBwGA1UECgwVV2lsZS4gRS4gQ295b3RlLCBMdGQuMRkwFwYD +VQQDDBB3aWxlLmV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE +e+Y4hxt66em0STviGUj6ZDbxzoLoubXWRml8JDFrEc2S2433KWw2npxYSKVCyo3a +/Vo33V8/H0WgOXioKEZJxDAKBggqhkjOPQQDAgNHADBEAiAN+87hQuGv3z0Ja2BV +b8PHq2vp3BJHjeMuxWu4BFPn0QIgYlvIHikspgGatXRNMZ1gPC0oCccsJFcie+Cw +zL06UPI= +-----END CERTIFICATE-----)"; + +// host.key +constexpr const char* key_pem = R"(-----BEGIN EC PARAMETERS----- +BggqhkjOPQMBBw== +-----END EC PARAMETERS----- +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEID6di2PFYn8hPrxPbkFDGkSqF+K8L520In7nx3g0jwzOoAoGCCqGSM49 +AwEHoUQDQgAEe+Y4hxt66em0STviGUj6ZDbxzoLoubXWRml8JDFrEc2S2433KWw2 +npxYSKVCyo3a/Vo33V8/H0WgOXioKEZJxA== +-----END EC PRIVATE KEY-----)"; + +namespace vespalib::test { + +vespalib::net::tls::TransportSecurityOptions make_tls_options_for_testing() { + return vespalib::net::tls::TransportSecurityOptions(ca_pem, cert_pem, key_pem); +} + +} // namespace vespalib::test diff --git a/vespalib/src/vespa/vespalib/test/make_tls_options_for_testing.h b/vespalib/src/vespa/vespalib/test/make_tls_options_for_testing.h new file mode 100644 index 00000000000..a1f1d5958f9 --- /dev/null +++ b/vespalib/src/vespa/vespalib/test/make_tls_options_for_testing.h @@ -0,0 +1,15 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/vespalib/net/tls/transport_security_options.h> + +namespace vespalib::test { + +/** + * Make security options allowing you to talk to yourself using + * TLS. This is intended for testing purposes only. + **/ +vespalib::net::tls::TransportSecurityOptions make_tls_options_for_testing(); + +} // namespace vespalib::test diff --git a/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.h b/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.h index 9679e6379f5..6e8fa368df7 100644 --- a/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.h +++ b/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.h @@ -23,10 +23,10 @@ namespace thread { class ThreadInit; } // init function when creating an executor to inject a frame with the // given name into the stack of all worker threads. -#define VESPA_THREAD_STACK_TAG(name) \ - int name(Runnable &worker) { \ - worker.run(); \ - return 1; \ +#define VESPA_THREAD_STACK_TAG(name) \ + int name(::vespalib::Runnable &worker) { \ + worker.run(); \ + return 1; \ } /** diff --git a/vespalog/pom.xml b/vespalog/pom.xml index 6443769afbe..7b167ee2c1c 100644 --- a/vespalog/pom.xml +++ b/vespalog/pom.xml @@ -50,6 +50,10 @@ <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-surefire-plugin</artifactId> <configuration> + <forkMode>once</forkMode> + <environmentVariables> + <VESPA_HOME>${project.build.directory}</VESPA_HOME> + </environmentVariables> <redirectTestOutputToFile>${test.hide}</redirectTestOutputToFile> </configuration> </plugin> diff --git a/vespalog/src/main/java/com/yahoo/log/LogFileDb.java b/vespalog/src/main/java/com/yahoo/log/LogFileDb.java new file mode 100644 index 00000000000..d0fa64805bf --- /dev/null +++ b/vespalog/src/main/java/com/yahoo/log/LogFileDb.java @@ -0,0 +1,50 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.log; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.nio.file.StandardOpenOption.*; + +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +import static com.yahoo.vespa.defaults.Defaults.getDefaults; + + +/** + * @author arnej + * + * This class takes care of saving meta-data about a log-file, + * ensuring that we can enact policies about log retention. + **/ +public class LogFileDb { + + static final String DBDIR = "var/db/vespa/logfiledb/"; + + private static long dayStamp() { + long s = System.currentTimeMillis() / 1000; + return s / 100000; + } + + private static OutputStream metaFile() throws java.io.IOException { + String fn = getDefaults().underVespaHome(DBDIR + "logfiles." + dayStamp()); + Path path = Paths.get(fn); + return Files.newOutputStream(path, CREATE, APPEND); + } + + public static void nowLoggingTo(String filename) { + if (filename.contains("\n")) { + throw new IllegalArgumentException("Cannot use filename with newline: "+filename); + } + long s = System.currentTimeMillis() / 1000; + String meta = "" + s + " " + filename + "\n"; + byte[] data = meta.getBytes(UTF_8); + try (OutputStream out = metaFile()) { + out.write(data); + } catch (java.io.IOException e) { + System.err.println("Saving meta-data about logfile "+filename+" failed: "+e); + // ignore + } + } +} diff --git a/vespalog/src/test/java/com/yahoo/log/LogFileDbTest.java b/vespalog/src/test/java/com/yahoo/log/LogFileDbTest.java new file mode 100644 index 00000000000..4dd7bd0978c --- /dev/null +++ b/vespalog/src/test/java/com/yahoo/log/LogFileDbTest.java @@ -0,0 +1,29 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.log; + +import java.io.File; +import static com.yahoo.vespa.defaults.Defaults.getDefaults; +import org.junit.Test; + +/** + * @author arnej + */ +public class LogFileDbTest { + + @Test + public void canSave() { + System.err.println("VH: "+System.getenv("VESPA_HOME")); + File dir = new File(getDefaults().underVespaHome(LogFileDb.DBDIR)); + dir.mkdirs(); + if (dir.isDirectory()) { + System.err.println("using directory: "+dir); + new File(getDefaults().underVespaHome("logs/extra")).mkdirs(); + String fn = getDefaults().underVespaHome("logs/extra/foo-bar.log"); + LogFileDb.nowLoggingTo(fn); + fn = getDefaults().underVespaHome("logs/extra/stamped-1.log"); + LogFileDb.nowLoggingTo(fn); + } else { + System.err.println("cannot create directory: "+dir); + } + } +} diff --git a/vespamalloc/src/vespamalloc/util/osmem.cpp b/vespamalloc/src/vespamalloc/util/osmem.cpp index d7d32f4844a..f4fbb376265 100644 --- a/vespamalloc/src/vespamalloc/util/osmem.cpp +++ b/vespamalloc/src/vespamalloc/util/osmem.cpp @@ -10,7 +10,8 @@ namespace vespamalloc { -void * MmapMemory::reserve(size_t & len) +void * +MmapMemory::reserve(size_t & len) { len = 0; const size_t wLen(0x1000); @@ -20,10 +21,11 @@ void * MmapMemory::reserve(size_t & len) (void) test; setStart(wanted); setEnd(getStart()); - return NULL; + return nullptr; } -size_t findInMemInfo(const char * wanted) +size_t +findInMemInfo(const char * wanted) { size_t value(0); char memInfo[8192]; @@ -34,16 +36,17 @@ size_t findInMemInfo(const char * wanted) assert((sz < int(sizeof(memInfo))) && (sz >= 0)); memInfo[sz] = '\0'; const char * found(strstr(memInfo, wanted)); - if (found != NULL) { + if (found != nullptr) { found += strlen(wanted); - value = strtoul(found, NULL, 0); + value = strtoul(found, nullptr, 0); } close(fd); } return value; } -const char * getToken(const char * & s, const char * e) +const char * +getToken(const char * & s, const char * e) { for (; (s < e) && isspace(s[0]); s++) { } const char * c = s; @@ -51,7 +54,8 @@ const char * getToken(const char * & s, const char * e) return c; } -bool verifyHugePagesMount(const char * mount) +bool +verifyHugePagesMount(const char * mount) { const unsigned int HUGETLBFS_MAGIC(0x958458f6); struct statfs64 st; @@ -70,15 +74,17 @@ MmapMemory::MmapMemory(size_t blockSize) : setupHugePages(); } -void MmapMemory::setupFAdvise() +void +MmapMemory::setupFAdvise() { const char * madv = getenv("VESPA_MALLOC_MADVISE_LIMIT"); if (madv) { - _useMAdvLimit = strtoul(madv, NULL, 0); + _useMAdvLimit = strtoul(madv, nullptr, 0); } } -void MmapMemory::setupHugePages() +void +MmapMemory::setupHugePages() { _hugePagesFileName[0] = '\0'; const char * vespaHugePages = getenv("VESPA_MALLOC_HUGEPAGES"); @@ -140,23 +146,29 @@ MmapMemory::~MmapMemory() } } -void * MmapMemory::get(size_t len) +void * +MmapMemory::get(size_t len) { - void * memory(NULL); + void * memory(nullptr); + int prevErrno = errno; memory = getHugePages(len); - if (memory ==NULL) { + if (memory == nullptr) { + errno = prevErrno; // The temporary error should not impact if the end is good. memory = getNormalPages(len); } return memory; } -void * MmapMemory::getHugePages(size_t len) +void * +MmapMemory::getHugePages(size_t len) { - void * memory(NULL); + void * memory(nullptr); if ( ((len & 0x1fffff) == 0) && len) { + int prevErrno = errno; memory = getBasePages(len, MAP_ANON | MAP_PRIVATE | MAP_HUGETLB, -1, 0); - if (memory == NULL) { + if (memory == nullptr) { if (_hugePagesFd >= 0) { + errno = prevErrno; // The temporary error should not impact if the end is good. memory = getBasePages(len, MAP_SHARED, _hugePagesFd, _hugePagesOffset); if (memory) { _hugePagesOffset += len; @@ -167,21 +179,22 @@ void * MmapMemory::getHugePages(size_t len) return memory; } -void * MmapMemory::getNormalPages(size_t len) +void * +MmapMemory::getNormalPages(size_t len) { return getBasePages(len, MAP_ANON | MAP_PRIVATE, -1, 0); } -void * MmapMemory::getBasePages(size_t len, int mmapOpt, int fd, size_t offset) +void * +MmapMemory::getBasePages(size_t len, int mmapOpt, int fd, size_t offset) { char * wanted = reinterpret_cast<char *>(std::max(reinterpret_cast<size_t>(getEnd()), getMinPreferredStartAddress())); - void * mem(NULL); + void * mem(nullptr); for (bool ok(false) ; !ok && (mem != MAP_FAILED); wanted += getBlockAlignment()) { - if (mem != NULL) { + if (mem != nullptr) { int tmp(munmap(mem, len)); assert(tmp == 0); (void) tmp; - mem = NULL; } // no alignment to _blockSize needed? // both 0x10000000000ul*4 and 0x200000 are multiples of the current block size. @@ -189,7 +202,7 @@ void * MmapMemory::getBasePages(size_t len, int mmapOpt, int fd, size_t offset) ok = (mem == wanted); } if (mem != MAP_FAILED) { - if (getStart() == NULL) { + if (getStart() == nullptr) { setStart(mem); // assumes len parameter is always multiple of the current block size. setEnd(static_cast<char *>(mem)+len); @@ -198,10 +211,11 @@ void * MmapMemory::getBasePages(size_t len, int mmapOpt, int fd, size_t offset) } return mem; } - return NULL; + return nullptr; } -bool MmapMemory::release(void * mem, size_t len) +bool +MmapMemory::release(void * mem, size_t len) { int ret(0); if (_useMAdvLimit <= len) { @@ -214,7 +228,8 @@ bool MmapMemory::release(void * mem, size_t len) return true; } -bool MmapMemory::freeTail(void * mem, size_t len) +bool +MmapMemory::freeTail(void * mem, size_t len) { int ret(0); if ((_useMAdvLimit <= len) && (static_cast<char *>(mem) + len) == getEnd()) { @@ -225,7 +240,8 @@ bool MmapMemory::freeTail(void * mem, size_t len) return (ret == 0); } -bool MmapMemory::reclaim(void * mem, size_t len) +bool +MmapMemory::reclaim(void * mem, size_t len) { int ret(0); if (_useMAdvLimit <= len) { diff --git a/vespamalloc/src/vespamalloc/util/osmem.h b/vespamalloc/src/vespamalloc/util/osmem.h index 4ccc2bc112c..2faf3c9b181 100644 --- a/vespamalloc/src/vespamalloc/util/osmem.h +++ b/vespamalloc/src/vespamalloc/util/osmem.h @@ -13,7 +13,7 @@ namespace vespamalloc { class Memory { public: - Memory(size_t blockSize) : _blockSize(std::max(blockSize, size_t(getpagesize()))), _start(NULL), _end(NULL) { } + Memory(size_t blockSize) : _blockSize(std::max(blockSize, size_t(getpagesize()))), _start(nullptr), _end(nullptr) { } virtual ~Memory() { } void * getStart() const { return _start; } void * getEnd() const { return _end; } diff --git a/vsm/src/vespa/vsm/vsm/docsumconfig.cpp b/vsm/src/vespa/vsm/vsm/docsumconfig.cpp index 7df2205bf39..25c13967c49 100644 --- a/vsm/src/vespa/vsm/vsm/docsumconfig.cpp +++ b/vsm/src/vespa/vsm/vsm/docsumconfig.cpp @@ -23,7 +23,8 @@ DynamicDocsumConfig::createFieldWriter(const string & fieldName, const string & fieldWriter.reset(new EmptyDFW()); rc = true; } else if ((overrideName == "attribute") || - ((overrideName == "geopos"))) { + (overrideName == "attributecombiner") || + (overrideName == "geopos")) { rc = true; } else { fieldWriter = search::docsummary::DynamicDocsumConfig::createFieldWriter(fieldName, overrideName, argument, rc); diff --git a/zkfacade/src/main/java/com/yahoo/vespa/curator/Lock.java b/zkfacade/src/main/java/com/yahoo/vespa/curator/Lock.java index d71660a990f..23fa3cccad2 100644 --- a/zkfacade/src/main/java/com/yahoo/vespa/curator/Lock.java +++ b/zkfacade/src/main/java/com/yahoo/vespa/curator/Lock.java @@ -3,15 +3,13 @@ package com.yahoo.vespa.curator; import com.google.common.util.concurrent.UncheckedTimeoutException; import com.yahoo.transaction.Mutex; -import org.apache.curator.framework.CuratorFramework; import org.apache.curator.framework.recipes.locks.InterProcessLock; -import org.apache.curator.framework.recipes.locks.InterProcessMutex; import java.time.Duration; import java.util.concurrent.TimeUnit; /** - * A cluster-wide reentrant mutex which is released on (the last symmetric) close + * A cluster-wide re-entrant mutex which is released on (the last symmetric) close * * @author bratseth */ @@ -20,13 +18,6 @@ public class Lock implements Mutex { private final InterProcessLock mutex; private final String lockPath; - /** @deprecated pass a Curator instance instead */ - @Deprecated - public Lock(String lockPath, CuratorFramework curator) { - this.lockPath = lockPath; - mutex = new InterProcessMutex(curator, lockPath); - } - public Lock(String lockPath, Curator curator) { this.lockPath = lockPath; mutex = curator.createMutex(lockPath); diff --git a/zkfacade/src/main/java/com/yahoo/vespa/curator/mock/MockCurator.java b/zkfacade/src/main/java/com/yahoo/vespa/curator/mock/MockCurator.java index 67eedfc5dba..4013cf1d649 100644 --- a/zkfacade/src/main/java/com/yahoo/vespa/curator/mock/MockCurator.java +++ b/zkfacade/src/main/java/com/yahoo/vespa/curator/mock/MockCurator.java @@ -14,7 +14,6 @@ import org.apache.curator.CuratorZookeeperClient; import org.apache.curator.framework.CuratorFramework; import org.apache.curator.framework.api.ACLBackgroundPathAndBytesable; import org.apache.curator.framework.api.ACLCreateModeBackgroundPathAndBytesable; -import org.apache.curator.framework.api.ACLCreateModePathAndBytesable; import org.apache.curator.framework.api.ACLPathAndBytesable; import org.apache.curator.framework.api.BackgroundCallback; import org.apache.curator.framework.api.BackgroundPathAndBytesable; @@ -26,8 +25,6 @@ import org.apache.curator.framework.api.CreateBuilder; import org.apache.curator.framework.api.CuratorListener; import org.apache.curator.framework.api.CuratorWatcher; import org.apache.curator.framework.api.DeleteBuilder; -import org.apache.curator.framework.api.ErrorListenerPathAndBytesable; -import org.apache.curator.framework.api.ErrorListenerPathable; import org.apache.curator.framework.api.ExistsBuilder; import org.apache.curator.framework.api.ExistsBuilderMain; import org.apache.curator.framework.api.GetACLBuilder; @@ -42,7 +39,6 @@ import org.apache.curator.framework.api.SetDataBackgroundVersionable; import org.apache.curator.framework.api.SetDataBuilder; import org.apache.curator.framework.api.SyncBuilder; import org.apache.curator.framework.api.UnhandledErrorListener; -import org.apache.curator.framework.api.VersionPathAndBytesable; import org.apache.curator.framework.api.WatchPathable; import org.apache.curator.framework.api.Watchable; import org.apache.curator.framework.api.transaction.CuratorTransaction; @@ -98,7 +94,7 @@ import static com.yahoo.vespa.curator.mock.MemoryFileSystem.Node; * Due to the "fluent API" style of Curator managing to break JavaDoc at a fundamental level, there is no * documentation on the contract of each method. The behavior here is deduced by observing what using code exists * and peeking at the Curator code. It may be incorrect in some corner cases.</p> - * + * * <p>Contains some code from PathUtils in ZooKeeper, licensed under the Apache 2.0 license.</p> * * @author bratseth @@ -628,6 +624,30 @@ public class MockCurator extends Curator { throw new UnsupportedOperationException("Not implemented in MockCurator"); } + public PathAndBytesable<T> inBackground() { + throw new UnsupportedOperationException("Not implemented in MockCurator"); + } + + public PathAndBytesable<T> inBackground(Object o) { + throw new UnsupportedOperationException("Not implemented in MockCurator"); + } + + public PathAndBytesable<T> inBackground(BackgroundCallback backgroundCallback) { + throw new UnsupportedOperationException("Not implemented in MockCurator"); + } + + public PathAndBytesable<T> inBackground(BackgroundCallback backgroundCallback, Object o) { + throw new UnsupportedOperationException("Not implemented in MockCurator"); + } + + public PathAndBytesable<T> inBackground(BackgroundCallback backgroundCallback, Executor executor) { + throw new UnsupportedOperationException("Not implemented in MockCurator"); + } + + public PathAndBytesable<T> inBackground(BackgroundCallback backgroundCallback, Object o, Executor executor) { + throw new UnsupportedOperationException("Not implemented in MockCurator"); + } + public ACLBackgroundPathAndBytesable<T> withMode(CreateMode createMode) { throw new UnsupportedOperationException("Not implemented in MockCurator"); } @@ -695,71 +715,37 @@ public class MockCurator extends Curator { return createNode(s, bytes, createParents, createMode, fileSystem.root(), listeners); } - @Override - public ErrorListenerPathAndBytesable<String> inBackground() { - throw new UnsupportedOperationException("Not implemented in MockCurator"); - } - - @Override - public ErrorListenerPathAndBytesable<String> inBackground(Object o) { - throw new UnsupportedOperationException("Not implemented in MockCurator"); - } - - @Override - public ErrorListenerPathAndBytesable<String> inBackground(BackgroundCallback backgroundCallback) { - throw new UnsupportedOperationException("Not implemented in MockCurator"); - } - - @Override - public ErrorListenerPathAndBytesable<String> inBackground(BackgroundCallback backgroundCallback, Object o) { - throw new UnsupportedOperationException("Not implemented in MockCurator"); - } - - @Override - public ErrorListenerPathAndBytesable<String> inBackground(BackgroundCallback backgroundCallback, Executor executor) { - throw new UnsupportedOperationException("Not implemented in MockCurator"); - } - - @Override - public ErrorListenerPathAndBytesable<String> inBackground(BackgroundCallback backgroundCallback, Object o, Executor executor) { - throw new UnsupportedOperationException("Not implemented in MockCurator"); - } } private class MockBackgroundPathableBuilder<T> implements BackgroundPathable<T>, Watchable<BackgroundPathable<T>> { @Override - public ErrorListenerPathable<T> inBackground() { + public Pathable<T> inBackground() { throw new UnsupportedOperationException("Not implemented in MockCurator"); } @Override - public ErrorListenerPathable<T> inBackground(Object o) { + public Pathable<T> inBackground(Object o) { throw new UnsupportedOperationException("Not implemented in MockCurator"); } @Override - public ErrorListenerPathable<T> inBackground(BackgroundCallback backgroundCallback) { + public Pathable<T> inBackground(BackgroundCallback backgroundCallback) { throw new UnsupportedOperationException("Not implemented in MockCurator"); } @Override - public ErrorListenerPathable<T> inBackground(BackgroundCallback backgroundCallback, Object o) { + public Pathable<T> inBackground(BackgroundCallback backgroundCallback, Object o) { throw new UnsupportedOperationException("Not implemented in MockCurator"); } @Override - public ErrorListenerPathable<T> inBackground(BackgroundCallback backgroundCallback, Executor executor) { + public Pathable<T> inBackground(BackgroundCallback backgroundCallback, Executor executor) { throw new UnsupportedOperationException("Not implemented in MockCurator"); } @Override - public ErrorListenerPathable<T> inBackground(BackgroundCallback backgroundCallback, Object o, Executor executor) { - throw new UnsupportedOperationException("Not implemented in MockCurator"); - } - - @Override - public T forPath(String s) throws Exception { + public Pathable<T> inBackground(BackgroundCallback backgroundCallback, Object o, Executor executor) { throw new UnsupportedOperationException("Not implemented in MockCurator"); } @@ -777,6 +763,11 @@ public class MockCurator extends Curator { public BackgroundPathable<T> usingWatcher(CuratorWatcher curatorWatcher) { throw new UnsupportedOperationException("Not implemented in MockCurator"); } + + public T forPath(String path) throws Exception { + throw new UnsupportedOperationException("Not implemented in MockCurator"); + } + } private class MockGetChildrenBuilder extends MockBackgroundPathableBuilder<List<String>> implements GetChildrenBuilder { @@ -872,35 +863,6 @@ public class MockCurator extends Curator { return null; } - @Override - public ErrorListenerPathAndBytesable<Stat> inBackground() { - throw new UnsupportedOperationException("Not implemented in MockCurator"); - } - - @Override - public ErrorListenerPathAndBytesable<Stat> inBackground(Object o) { - throw new UnsupportedOperationException("Not implemented in MockCurator"); - } - - @Override - public ErrorListenerPathAndBytesable<Stat> inBackground(BackgroundCallback backgroundCallback) { - throw new UnsupportedOperationException("Not implemented in MockCurator"); - } - - @Override - public ErrorListenerPathAndBytesable<Stat> inBackground(BackgroundCallback backgroundCallback, Object o) { - throw new UnsupportedOperationException("Not implemented in MockCurator"); - } - - @Override - public ErrorListenerPathAndBytesable<Stat> inBackground(BackgroundCallback backgroundCallback, Executor executor) { - throw new UnsupportedOperationException("Not implemented in MockCurator"); - } - - @Override - public ErrorListenerPathAndBytesable<Stat> inBackground(BackgroundCallback backgroundCallback, Object o, Executor executor) { - throw new UnsupportedOperationException("Not implemented in MockCurator"); - } } /** Allows addition of directoryListeners which are never called */ @@ -979,7 +941,7 @@ public class MockCurator extends Curator { } @Override - public ACLCreateModePathAndBytesable<CuratorTransactionBridge> compressed() { + public ACLPathAndBytesable<CuratorTransactionBridge> compressed() { throw new UnsupportedOperationException("Not implemented in MockCurator"); } @@ -1021,7 +983,7 @@ public class MockCurator extends Curator { private class MockTransactionSetDataBuilder implements TransactionSetDataBuilder { @Override - public VersionPathAndBytesable<CuratorTransactionBridge> compressed() { + public PathAndBytesable<CuratorTransactionBridge> compressed() { throw new UnsupportedOperationException("Not implemented in MockCurator"); } diff --git a/zkfacade/src/main/java/org/apache/curator/framework/api/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/api/package-info.java index b1b6c84838f..7fff06187ba 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/api/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/api/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 2, minor = 12, micro = 0)) +@ExportPackage(version = @Version(major = 2, minor = 9, micro = 1)) package org.apache.curator.framework.api; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/api/transaction/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/api/transaction/package-info.java index e525ef138ba..c96ddcc7f16 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/api/transaction/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/api/transaction/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 2, minor = 12, micro = 0)) +@ExportPackage(version = @Version(major = 2, minor = 9, micro = 1)) package org.apache.curator.framework.api.transaction; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/listen/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/listen/package-info.java index faa48e0c074..3777974a5d1 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/listen/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/listen/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 2, minor = 12, micro = 0)) +@ExportPackage(version = @Version(major = 2, minor = 9, micro = 1)) package org.apache.curator.framework.listen; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/package-info.java index aedbfeed15b..578dde579a5 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 2, minor = 12, micro = 0)) +@ExportPackage(version = @Version(major = 2, minor = 9, micro = 1)) package org.apache.curator.framework; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/recipes/atomic/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/recipes/atomic/package-info.java index db91e7396ea..eabcf404db1 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/recipes/atomic/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/recipes/atomic/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 2, minor = 12, micro = 0)) +@ExportPackage(version = @Version(major = 2, minor = 9, micro = 1)) package org.apache.curator.framework.recipes.atomic; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/recipes/barriers/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/recipes/barriers/package-info.java index 0349194c402..827b92e7a3b 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/recipes/barriers/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/recipes/barriers/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 2, minor = 12, micro = 0)) +@ExportPackage(version = @Version(major = 2, minor = 9, micro = 1)) package org.apache.curator.framework.recipes.barriers; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/recipes/cache/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/recipes/cache/package-info.java index 9e8dc0aa0ef..cd3d5d5cabe 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/recipes/cache/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/recipes/cache/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 2, minor = 12, micro = 0)) +@ExportPackage(version = @Version(major = 2, minor = 9, micro = 1)) package org.apache.curator.framework.recipes.cache; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/recipes/locks/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/recipes/locks/package-info.java index e7530a85539..915b60b9241 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/recipes/locks/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/recipes/locks/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 2, minor = 12, micro = 0)) +@ExportPackage(version = @Version(major = 2, minor = 9, micro = 1)) package org.apache.curator.framework.recipes.locks; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/state/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/state/package-info.java index 8ee09dc87be..4c353b8ba06 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/state/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/state/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 2, minor = 12, micro = 0)) +@ExportPackage(version = @Version(major = 2, minor = 9, micro = 1)) package org.apache.curator.framework.state; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/package-info.java b/zkfacade/src/main/java/org/apache/curator/package-info.java index 1f55cbc8a3e..736ec4b6f78 100644 --- a/zkfacade/src/main/java/org/apache/curator/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 2, minor = 12, micro = 0)) +@ExportPackage(version = @Version(major = 2, minor = 9, micro = 1)) package org.apache.curator; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/retry/package-info.java b/zkfacade/src/main/java/org/apache/curator/retry/package-info.java index eb4a8c7ff5f..c931656f867 100644 --- a/zkfacade/src/main/java/org/apache/curator/retry/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/retry/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 2, minor = 12, micro = 0)) +@ExportPackage(version = @Version(major = 2, minor = 9, micro = 1)) package org.apache.curator.retry; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; |