diff options
421 files changed, 22089 insertions, 2656 deletions
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java index 5a509d77431..0b64f206267 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java +++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java @@ -13,9 +13,9 @@ import com.yahoo.vespa.athenz.client.zts.DefaultZtsClient; import com.yahoo.vespa.athenz.client.zts.Identity; import com.yahoo.vespa.athenz.client.zts.ZtsClient; import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider; -import com.yahoo.vespa.athenz.tls.KeyStoreBuilder; -import com.yahoo.vespa.athenz.tls.KeyStoreType; -import com.yahoo.vespa.athenz.tls.KeyUtils; +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.KeyUtils; import com.yahoo.vespa.athenz.utils.SiaUtils; import com.yahoo.vespa.defaults.Defaults; import com.yahoo.vespa.hosted.athenz.instanceproviderservice.config.AthenzProviderServiceConfig; diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslTrustStoreConfigurator.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslTrustStoreConfigurator.java index a440f96cc49..da5fd430f1c 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslTrustStoreConfigurator.java +++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslTrustStoreConfigurator.java @@ -4,11 +4,11 @@ package com.yahoo.vespa.hosted.athenz.instanceproviderservice; import com.google.inject.Inject; import com.yahoo.jdisc.http.ssl.SslTrustStoreConfigurator; import com.yahoo.jdisc.http.ssl.SslTrustStoreContext; -import com.yahoo.vespa.athenz.tls.KeyStoreBuilder; -import com.yahoo.vespa.athenz.tls.KeyStoreType; +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyStoreType; import com.yahoo.vespa.hosted.athenz.instanceproviderservice.config.AthenzProviderServiceConfig; -import java.io.File; +import java.nio.file.Paths; import java.security.KeyStore; import java.security.KeyStoreException; import java.security.cert.X509Certificate; @@ -43,7 +43,7 @@ public class AthenzSslTrustStoreConfigurator implements SslTrustStoreConfigurato private static KeyStore createTrustStore(AthenzProviderServiceConfig athenzProviderServiceConfig) { try { return KeyStoreBuilder.withType(KeyStoreType.JKS) - .fromFile(new File(athenzProviderServiceConfig.athenzCaTrustStore())) + .fromFile(Paths.get(athenzProviderServiceConfig.athenzCaTrustStore())) .build(); } catch (Exception e) { throw new RuntimeException(e); diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/CkmsKeyProvider.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/CkmsKeyProvider.java index 183a52f782c..40003d4ccf3 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/CkmsKeyProvider.java +++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/CkmsKeyProvider.java @@ -4,7 +4,7 @@ package com.yahoo.vespa.hosted.athenz.instanceproviderservice.impl; import com.google.inject.Inject; import com.yahoo.config.provision.Zone; import com.yahoo.container.jdisc.secretstore.SecretStore; -import com.yahoo.vespa.athenz.tls.KeyUtils; +import com.yahoo.security.KeyUtils; import com.yahoo.vespa.hosted.athenz.instanceproviderservice.KeyProvider; import com.yahoo.vespa.hosted.athenz.instanceproviderservice.config.AthenzProviderServiceConfig; diff --git a/config-lib/src/main/java/com/yahoo/config/PathNode.java b/config-lib/src/main/java/com/yahoo/config/PathNode.java index b63dad4d1a7..9d73b5e23c2 100644 --- a/config-lib/src/main/java/com/yahoo/config/PathNode.java +++ b/config-lib/src/main/java/com/yahoo/config/PathNode.java @@ -14,7 +14,6 @@ import java.util.Map; * Represents a 'path' in a {@link ConfigInstance}, usually a filename. * * @author gjoranv - * @since 5.1.30 */ public class PathNode extends LeafNode<Path> { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 8c7398b3dde..afd33da369f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -13,12 +13,16 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; +import java.util.ArrayDeque; import java.util.Collection; import java.util.Collections; +import java.util.Deque; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Stack; +import java.util.stream.Collectors; /** * A context which only contains type information. @@ -26,21 +30,29 @@ import java.util.Optional; * query, attribute or constant features, as we do not have information about which such * features exist (but we know those that exist are doubles). * + * This is not multithread safe. + * * @author bratseth */ public class MapEvaluationTypeContext extends FunctionReferenceContext implements TypeContext<Reference> { private final Map<Reference, TensorType> featureTypes = new HashMap<>(); - public MapEvaluationTypeContext(Collection<ExpressionFunction> functions) { + /** For invocation loop detection */ + private final Deque<Reference> currentResolutionCallStack; + + MapEvaluationTypeContext(Collection<ExpressionFunction> functions) { super(functions); + this.currentResolutionCallStack = new ArrayDeque<>(); } - public MapEvaluationTypeContext(Map<String, ExpressionFunction> functions, - Map<String, String> bindings, - Map<Reference, TensorType> featureTypes) { + private MapEvaluationTypeContext(Map<String, ExpressionFunction> functions, + Map<String, String> bindings, + Map<Reference, TensorType> featureTypes, + Deque<Reference> currentResolutionCallStack) { super(functions, bindings); this.featureTypes.putAll(featureTypes); + this.currentResolutionCallStack = currentResolutionCallStack; } public void setType(Reference reference, TensorType type) { @@ -54,6 +66,11 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement @Override public TensorType getType(Reference reference) { + if (currentResolutionCallStack.contains(reference)) + throw new IllegalArgumentException("Invocation loop: " + + currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) + + " -> " + reference); + // A reference to a macro argument? Optional<String> binding = boundIdentifier(reference); if (binding.isPresent()) { @@ -61,36 +78,42 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement // This is not pretty, but changing to bind expressions rather // than their string values requires deeper changes return new RankingExpression(binding.get()).type(this); - } - catch (ParseException e) { + } catch (ParseException e) { throw new IllegalArgumentException(e); } } - // A reference to an attribute, query or constant feature? - if (FeatureNames.isSimpleFeature(reference)) { - // The argument may be a local identifier bound to the actual value - String argument = reference.simpleArgument().get(); - reference = Reference.simple(reference.name(), bindings.getOrDefault(argument, argument)); - return featureTypes.getOrDefault(reference, defaultTypeOf(reference)); - } + try { + currentResolutionCallStack.addLast(reference); - // A reference to a function? - Optional<ExpressionFunction> function = functionInvocation(reference); - if (function.isPresent()) { - return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments()))); - } + // A reference to an attribute, query or constant feature? + if (FeatureNames.isSimpleFeature(reference)) { + // The argument may be a local identifier bound to the actual value + String argument = reference.simpleArgument().get(); + reference = Reference.simple(reference.name(), bindings.getOrDefault(argument, argument)); + return featureTypes.getOrDefault(reference, defaultTypeOf(reference)); + } - // A reference to a feature which returns a tensor? - Optional<TensorType> featureTensorType = tensorFeatureType(reference); - if (featureTensorType.isPresent()) { - return featureTensorType.get(); - } + // A reference to a function? + Optional<ExpressionFunction> function = functionInvocation(reference); + if (function.isPresent()) { + return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments()))); + } + + // A reference to a feature which returns a tensor? + Optional<TensorType> featureTensorType = tensorFeatureType(reference); + if (featureTensorType.isPresent()) { + return featureTensorType.get(); + } - // We do not know what this is - since we do not have complete knowledge abut the match features - // in Java we must assume this is a match feature and return the double type - which is the type of all - // all match features - return TensorType.empty; + // We do not know what this is - since we do not have complete knowledge abut the match features + // in Java we must assume this is a match feature and return the double type - which is the type of all + // all match features + return TensorType.empty; + } + finally { + currentResolutionCallStack.removeLast(); + } } /** @@ -173,7 +196,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement @Override public MapEvaluationTypeContext withBindings(Map<String, String> bindings) { if (bindings.isEmpty() && this.bindings.isEmpty()) return this; - return new MapEvaluationTypeContext(functions(), bindings, featureTypes); + return new MapEvaluationTypeContext(functions(), bindings, featureTypes, currentResolutionCallStack); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java index 164cb7f808e..5ac1418c0c7 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java @@ -1,6 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; +import com.yahoo.config.FileReference; +import com.yahoo.vespa.model.AbstractService; +import com.yahoo.vespa.model.utils.FileSender; + +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -33,4 +38,14 @@ public class RankingConstants { return Collections.unmodifiableMap(constants); } + /** Initiate sending of these constants to some services over file distribution */ + public void sendTo(Collection<? extends AbstractService> services) { + for (RankingConstant constant : constants.values()) { + FileReference reference = (constant.getPathType() == RankingConstant.PathType.FILE) + ? FileSender.sendFileToServices(constant.getFileName(), services) + : FileSender.sendUriToServices(constant.getUri(), services); + constant.setFileReference(reference.value()); + } + } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 4af26b72817..9a00ee5bbd0 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -94,7 +94,7 @@ public class DerivedConfiguration { summaries = new Summaries(search, deployLogger); summaryMap = new SummaryMap(search, summaries); juniperrc = new Juniperrc(search); - rankProfileList = new RankProfileList(search, attributeFields, rankProfileRegistry, queryProfiles, importedModels); + rankProfileList = new RankProfileList(search, search.rankingConstants(), attributeFields, rankProfileRegistry, queryProfiles, importedModels); indexingScript = new IndexingScript(search); indexInfo = new IndexInfo(search); indexSchema = new IndexSchema(search); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 10881ab9ce0..fcbfb47c597 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -3,24 +3,35 @@ package com.yahoo.searchdefinition.derived; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.RankingConstants; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.Search; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import com.yahoo.vespa.model.AbstractService; + +import java.util.Collection; import java.util.Map; +import java.util.logging.Logger; /** * The derived rank profiles of a search definition * * @author bratseth */ -public class RankProfileList extends Derived implements RankProfilesConfig.Producer { +public class RankProfileList extends Derived implements RankProfilesConfig.Producer, RankingConstantsConfig.Producer { + + private static final Logger log = Logger.getLogger(RankProfileList.class.getName()); private final Map<String, RawRankProfile> rankProfiles = new java.util.LinkedHashMap<>(); + private final RankingConstants rankingConstants; public static RankProfileList empty = new RankProfileList(); private RankProfileList() { + this.rankingConstants = new RankingConstants(); } /** @@ -30,11 +41,13 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ * @param attributeFields the attribute fields to create a ranking for */ public RankProfileList(Search search, + RankingConstants rankingConstants, AttributeFields attributeFields, RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles, ImportedModels importedModels) { setName(search == null ? "default" : search.getName()); + this.rankingConstants = rankingConstants; deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields); } @@ -68,6 +81,10 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ return rankProfiles.get(name); } + public void sendConstantsTo(Collection<? extends AbstractService> services) { + rankingConstants.sendTo(services); + } + @Override public String getDerivedName() { return "rank-profiles"; } @@ -78,4 +95,17 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ } } + @Override + public void getConfig(RankingConstantsConfig.Builder builder) { + for (RankingConstant constant : rankingConstants.asMap().values()) { + if ("".equals(constant.getFileReference())) + log.warning("Illegal file reference " + constant); // Let tests pass ... we should find a better way + else + builder.constant(new RankingConstantsConfig.Constant.Builder() + .name(constant.getName()) + .fileref(constant.getFileReference()) + .type(constant.getType())); + } + } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 3e9d188670e..1b15233fead 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -169,6 +169,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri deployState.rankProfileRegistry(), deployState.getQueryProfiles().getRegistry()); this.rankProfileList = new RankProfileList(null, // null search -> global + rankingConstants, AttributeFields.empty, deployState.rankProfileRegistry(), deployState.getQueryProfiles().getRegistry(), diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java index e8985b094ac..73d77406700 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java @@ -135,6 +135,8 @@ public class VespaMetricSet { metrics.add(new Metric("http.status.3xx.rate")); metrics.add(new Metric("http.status.4xx.rate")); metrics.add(new Metric("http.status.5xx.rate")); + metrics.add(new Metric("http.status.401.rate")); + metrics.add(new Metric("http.status.403.rate")); metrics.add(new Metric("jdisc.http.request.uri_length.average")); metrics.add(new Metric("jdisc.http.request.uri_length.max")); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java index e34d490afe1..095b27a0904 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java @@ -1,18 +1,22 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.builder.xml.dom; -import com.yahoo.component.Version; import com.yahoo.config.model.ConfigModelContext; import com.yahoo.config.model.api.ConfigServerSpec; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ClusterSpec; +import com.yahoo.config.provision.SystemName; +import com.yahoo.log.LogLevel; +import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.vespa.model.HostResource; import com.yahoo.vespa.model.HostSystem; import com.yahoo.vespa.model.admin.Admin; import com.yahoo.vespa.model.admin.Logserver; import com.yahoo.vespa.model.admin.Slobrok; import com.yahoo.vespa.model.container.Container; +import com.yahoo.vespa.model.container.ContainerCluster; import com.yahoo.vespa.model.container.ContainerModel; +import com.yahoo.vespa.model.container.component.Handler; import org.w3c.dom.Element; import java.util.ArrayList; @@ -73,14 +77,46 @@ public class DomAdminV4Builder extends DomAdminBuilderBase { if (nodesSpecification.count() > 1) throw new IllegalArgumentException("You can only request a single log server"); if (nodesSpecification.isDedicated()) { - createLogserver(admin, allocateHosts(admin.getHostSystem(), "logserver", nodesSpecification)); - } - else { - if (containerModels.iterator().hasNext()) - createLogserver(admin, sortedContainerHostsFrom(containerModels.iterator().next(), nodesSpecification.count(), false)); + Collection<HostResource> hosts = allocateHosts(admin.getHostSystem(), "logserver", nodesSpecification); + if (hosts.isEmpty()) return; // No log server can be created (and none is needed) + + Logserver logserver = createLogserver(admin, hosts); + // TODO: Enable for main system as well + if (context.getDeployState().isHosted() && context.getDeployState().zone().system() == SystemName.cd) + createAdditionalContainerOnLogserverHost(admin, logserver.getHostResource()); + } else if (containerModels.iterator().hasNext()) { + List<HostResource> hosts = sortedContainerHostsFrom(containerModels.iterator().next(), nodesSpecification.count(), false); + if (hosts.isEmpty()) return; // No log server can be created (and none is needed) + + createLogserver(admin, hosts); + } else { + context.getDeployLogger().log(LogLevel.INFO, "No container host available to use for running logserver"); } } + // Creates a container cluster 'logserver-cluster' with 1 container on logserver host + // for setting up a handler for getting logs from logserver + private void createAdditionalContainerOnLogserverHost(Admin admin, HostResource hostResource) { + ContainerCluster logServerCluster = new ContainerCluster(admin, "logserver-cluster", "logserver-cluster", RankProfileList.empty); + ContainerModel logserverClusterModel = new ContainerModel(context.withParent(admin).withId(logServerCluster.getSubId())); + logserverClusterModel.setCluster(logServerCluster); + + addLogHandler(logServerCluster); + + Container container = new Container(logServerCluster, "logserver-container", 0); + container.setHostResource(hostResource); + container.initService(); + logServerCluster.addContainer(container); + admin.addAndInitializeService(hostResource, container); + } + + // TODO: Wire in handler for getting logs + private void addLogHandler(ContainerCluster cluster) { + Handler<?> logHandler = Handler.fromClassName("TODO"); + //logHandler.addServerBindings("http://*/logs/", "https://*/logs/"); + cluster.addComponent(logHandler); + } + private Collection<HostResource> allocateHosts(HostSystem hostSystem, String clusterId, NodesSpecification nodesSpecification) { return nodesSpecification.provision(hostSystem, ClusterSpec.Type.admin, @@ -148,12 +184,12 @@ public class DomAdminV4Builder extends DomAdminBuilderBase { return HostResource.pickHosts(hosts, count, 1); } - private void createLogserver(Admin admin, Collection<HostResource> hosts) { - if (hosts.isEmpty()) return; // No log server can be created (and none is needed) + private Logserver createLogserver(Admin admin, Collection<HostResource> hosts) { Logserver logserver = new Logserver(admin); logserver.setHostResource(hosts.iterator().next()); admin.setLogserver(logserver); logserver.initService(); + return logserver; } private void createSlobroks(Admin admin, Collection<HostResource> hosts) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java index 8c6c13d810f..fbe86d26b02 100755 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java @@ -41,10 +41,9 @@ import com.yahoo.search.config.IndexInfoConfig; import com.yahoo.search.config.QrStartConfig; import com.yahoo.search.pagetemplates.PageTemplatesConfig; import com.yahoo.search.query.profile.config.QueryProfilesConfig; -import com.yahoo.searchdefinition.RankProfileRegistry; -import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.configdefinition.IlscriptsConfig; import com.yahoo.vespa.model.PortsMeta; import com.yahoo.vespa.model.Service; @@ -66,11 +65,9 @@ import com.yahoo.vespa.model.container.docproc.ContainerDocproc; import com.yahoo.vespa.model.container.docproc.DocprocChains; import com.yahoo.vespa.model.container.http.Http; import com.yahoo.vespa.model.container.jersey.Jersey2Servlet; -import com.yahoo.vespa.model.container.jersey.JerseyHandler; import com.yahoo.vespa.model.container.jersey.RestApi; import com.yahoo.vespa.model.container.processing.ProcessingChains; import com.yahoo.vespa.model.container.search.ContainerSearch; -import com.yahoo.vespa.model.container.search.QueryProfiles; import com.yahoo.vespa.model.container.search.searchchain.SearchChains; import com.yahoo.vespa.model.content.Content; import com.yahoo.vespa.model.search.AbstractSearchCluster; @@ -79,7 +76,6 @@ import com.yahoo.vespaclient.config.FeederConfig; import edu.umd.cs.findbugs.annotations.NonNull; import edu.umd.cs.findbugs.annotations.Nullable; - import java.nio.file.Path; import java.util.ArrayList; import java.util.Collection; @@ -91,7 +87,6 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; -import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -129,7 +124,8 @@ public final class ContainerCluster RoutingProviderConfig.Producer, ConfigserverConfig.Producer, ThreadpoolConfig.Producer, - RankProfilesConfig.Producer + RankProfilesConfig.Producer, + RankingConstantsConfig.Producer { @@ -364,6 +360,7 @@ public final class ContainerCluster public void prepare() { addAndSendApplicationBundles(); + rankProfileList.sendConstantsTo(containers); sendUserConfiguredFiles(); setApplicationMetaData(); for (RestApi restApi : restApiGroup.getComponents()) @@ -562,10 +559,6 @@ public final class ContainerCluster @Override public final void getConfig(JdiscBindingsConfig.Builder builder) { builder.handlers.putAll(DiscBindingsConfigGenerator.generate(getHandlers())); - - allJersey1Handlers().forEach(handler -> - builder.handlers.putAll(DiscBindingsConfigGenerator.generate(handler)) - ); } @Override @@ -573,10 +566,6 @@ public final class ContainerCluster clusterVerifier.getConfig(builder); } - private Stream<JerseyHandler> allJersey1Handlers() { - return restApiGroup.getComponents().stream().flatMap(streamOf(RestApi::getJersey1Handler)); - } - @Override public void getConfig(ServletPathsConfig.Builder builder) { allServlets().forEach(servlet -> @@ -591,14 +580,7 @@ public final class ContainerCluster } private Stream<Jersey2Servlet> allJersey2Servlets() { - return restApiGroup.getComponents().stream().flatMap(streamOf(RestApi::getJersey2Servlet)); - } - - private <T, R> Function<T, Stream<R>> streamOf(Function<T, Optional<R>> f) { - return t -> - f.apply(t). - <Stream<R>>map(Stream::of). - orElse(Stream.empty()); + return restApiGroup.getComponents().stream().map(RestApi::getJersey2Servlet); } @Override @@ -732,6 +714,9 @@ public final class ContainerCluster rankProfileList.getConfig(builder); } + @Override + public void getConfig(RankingConstantsConfig.Builder builder) { rankProfileList.getConfig(builder); } + public void setMbusParams(MbusParams mbusParams) { this.mbusParams = mbusParams; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/JerseyHandler.java b/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/JerseyHandler.java deleted file mode 100644 index 737882b703d..00000000000 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/JerseyHandler.java +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.model.container.jersey; - -import com.yahoo.config.model.producer.AbstractConfigProducer; -import com.yahoo.container.bundle.BundleInstantiationSpecification; -import com.yahoo.osgi.provider.model.ComponentModel; -import com.yahoo.vespa.model.container.component.Handler; - -/** - * @author gjoranv - * @since 5.6 - */ -public class JerseyHandler extends Handler<AbstractConfigProducer<?>> { - - public static final String BUNDLE = "container-jersey"; - public static final String CLASS = "com.yahoo.container.jdisc.jersey.JerseyHandler"; - - public JerseyHandler(String bindingPath) { - super(new ComponentModel(bundleSpec(CLASS, BUNDLE, bindingPath))); - } - - public static BundleInstantiationSpecification bundleSpec(String className, String bundle, String bindingPath) { - return BundleInstantiationSpecification.getFromStrings( - className + "-" + RestApi.idFromPath(bindingPath), - className, - bundle); - } -} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApi.java b/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApi.java index 63825aa2a1b..be8209bcc4e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApi.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApi.java @@ -2,32 +2,24 @@ package com.yahoo.vespa.model.container.jersey; import com.yahoo.config.model.producer.AbstractConfigProducer; -import com.yahoo.container.config.jersey.JerseyInitConfig; -import com.yahoo.vespa.model.container.component.Component; - -import java.util.Optional; /** + * Represents a rest-api + * * @author gjoranv - * @since 5.6 */ -public class RestApi extends AbstractConfigProducer<AbstractConfigProducer<?>> implements - JerseyInitConfig.Producer -{ - public final boolean isJersey2; +public class RestApi extends AbstractConfigProducer<AbstractConfigProducer<?>> { + private final String bindingPath; - private final Component<?, ?> jerseyHandler; + private final Jersey2Servlet jerseyServlet; private RestApiContext restApiContext; - public RestApi(String bindingPath, boolean isJersey2) { + public RestApi(String bindingPath) { super(idFromPath(bindingPath)); this.bindingPath = bindingPath; - this.isJersey2 = isJersey2; - jerseyHandler = isJersey2 ? - createJersey2Servlet(this.bindingPath): - createJersey1Handler(this.bindingPath); - addChild(jerseyHandler); + jerseyServlet = createJersey2Servlet(this.bindingPath); + addChild(jerseyServlet); } public static String idFromPath(String path) { @@ -38,44 +30,20 @@ public class RestApi extends AbstractConfigProducer<AbstractConfigProducer<?>> i return new Jersey2Servlet(bindingPath); } - private static JerseyHandler createJersey1Handler(String bindingPath) { - JerseyHandler jerseyHandler = new JerseyHandler(bindingPath); - jerseyHandler.addServerBindings(getBindings(bindingPath)); - return jerseyHandler; - } - public String getBindingPath() { return bindingPath; } - @Override - public void getConfig(JerseyInitConfig.Builder builder) { - builder.jerseyMapping(bindingPath); - } - public void setRestApiContext(RestApiContext restApiContext) { this.restApiContext = restApiContext; addChild(restApiContext); - jerseyHandler.inject(restApiContext); + jerseyServlet.inject(restApiContext); } public RestApiContext getContext() { return restApiContext; } - public Optional<JerseyHandler> getJersey1Handler() { - return isJersey2 ? - Optional.empty(): - Optional.of((JerseyHandler)jerseyHandler); - } - - public Optional<Jersey2Servlet> getJersey2Servlet() { - return isJersey2 ? - Optional.of((Jersey2Servlet)jerseyHandler) : - Optional.empty(); - } - - private static String[] getBindings(String bindingPath) { - String bindingWithoutScheme = "://*/" + bindingPath + "/*"; - return new String[] {"http" + bindingWithoutScheme, "https" + bindingWithoutScheme}; + public Jersey2Servlet getJersey2Servlet() { + return jerseyServlet; } public void prepare() { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApiContext.java b/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApiContext.java index 5e48a1b1951..7fce9d2b636 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApiContext.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/RestApiContext.java @@ -22,7 +22,6 @@ import java.util.logging.Logger; /** * @author gjoranv - * @since 5.16 */ public class RestApiContext extends SimpleComponent implements JerseyBundlesConfig.Producer, @@ -87,10 +86,6 @@ public class RestApiContext extends SimpleComponent implements } } - public void addInjections(Map<String, String> injections) { - injectComponentForClass.putAll(injections); - } - @Override public void validate() throws Exception { super.validate(); @@ -117,7 +112,6 @@ public class RestApiContext extends SimpleComponent implements private Predicate<Component> isCycleGeneratingComponent = component -> { switch (component.getClassId().getName()) { case CONTAINER_CLASS: - case JerseyHandler.CLASS: case Jersey2Servlet.CLASS: case "com.yahoo.jdisc.http.server.jetty.JettyHttpServer": case "com.yahoo.container.handler.observability.ApplicationStatusHandler": diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/xml/RestApiBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/xml/RestApiBuilder.java index 245db3c014f..6728f0be29f 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/xml/RestApiBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/jersey/xml/RestApiBuilder.java @@ -13,8 +13,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalAttribute; - /** * @author gjoranv * @since 5.6 @@ -24,8 +22,7 @@ public class RestApiBuilder extends VespaDomBuilder.DomConfigProducerBuilder<Res @Override protected RestApi doBuild(AbstractConfigProducer ancestor, Element spec) { String bindingPath = spec.getAttribute("path"); - boolean jersey2 = Boolean.parseBoolean(getOptionalAttribute(spec, "jersey2").orElse("false")); - RestApi restApi = new RestApi(bindingPath, jersey2); + RestApi restApi = new RestApi(bindingPath); restApi.setRestApiContext( createRestApiContext(ancestor, spec, bindingPath)); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java index 19d014e0a1d..ceb48732116 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java @@ -15,6 +15,7 @@ import java.util.*; /** * Config producer for the FederationSearcher. + * * @author Tony Vaagenes */ public class FederationSearcher extends Searcher<FederationSearcherModel> implements FederationConfig.Producer { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/routing/Protocol.java b/config-model/src/main/java/com/yahoo/vespa/model/routing/Protocol.java index ad684894176..49596aa0ddf 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/routing/Protocol.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/routing/Protocol.java @@ -12,18 +12,10 @@ import com.yahoo.messagebus.routing.RoutingTableSpec; */ public interface Protocol { - /** - * Returns the specification for the routing table of this protocol. - * - * @return The routing table spec. - */ - public RoutingTableSpec getRoutingTableSpec(); + /** Returns the specification for the routing table of this protocol. */ + RoutingTableSpec getRoutingTableSpec(); - /** - * Returns the specification of the application as seen by this protocol. - * - * @return The application spec. - */ - public ApplicationSpec getApplicationSpec(); + /** Returns the specification of the application as seen by this protocol. */ + ApplicationSpec getApplicationSpec(); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/routing/Routing.java b/config-model/src/main/java/com/yahoo/vespa/model/routing/Routing.java index 16f51935f2a..2403594d331 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/routing/Routing.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/routing/Routing.java @@ -21,7 +21,7 @@ public class Routing extends ConfigModel { private final List<String> errors = new ArrayList<>(); private ApplicationSpec explicitApplication = null; private RoutingSpec explicitRouting = null; - private List<Protocol> protocols = new ArrayList<>(); + private final List<Protocol> protocols = new ArrayList<>(); private RoutingSpec derivedRouting; public Routing(ConfigModelContext modelContext) { @@ -91,7 +91,7 @@ public class Routing extends ConfigModel { } public void getConfig(MessagebusConfig.Builder builder) { - if (derivedRouting==null) { + if (derivedRouting == null) { // The error list should be populated then return; } @@ -198,4 +198,5 @@ public class Routing extends ConfigModel { public List<String> getErrors() { return Collections.unmodifiableList(errors); } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java index a6bf51a2503..b29ed0fc25b 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java @@ -75,16 +75,7 @@ public class DocumentDatabase extends AbstractConfigProducer implements @Override public void getConfig(RankingConstantsConfig.Builder builder) { - for (RankingConstant constant : derivedCfg.getSearch().rankingConstants().asMap().values()) { - if ("".equals(constant.getFileReference())) { - System.err.println("INVALID rank constant "+constant.getName()+" [missing file reference]"); // TODO: Throw or log warning - continue; - } - builder.constant(new RankingConstantsConfig.Constant.Builder() - .name(constant.getName()) - .fileref(constant.getFileReference()) - .type(constant.getType())); - } + derivedCfg.getRankProfileList().getConfig(builder); } @Override diff --git a/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax.onnx b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax.onnx Binary files differnew file mode 100644 index 00000000000..a86019bf53a --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax.onnx diff --git a/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/saved_model.pbtxt b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/saved_model.pbtxt new file mode 100644 index 00000000000..05b0e4e0f29 --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/saved_model.pbtxt @@ -0,0 +1,5039 @@ +saved_model_schema_version: 1 +meta_graphs { + meta_info_def { + stripped_op_list { + op { + name: "Add" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_STRING + } + } + } + } + op { + name: "ApplyGradientDescent" + input_arg { + name: "var" + type_attr: "T" + is_ref: true + } + input_arg { + name: "alpha" + type_attr: "T" + } + input_arg { + name: "delta" + type_attr: "T" + } + output_arg { + name: "out" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: false + } + } + } + op { + name: "ArgMax" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "dimension" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "output_type" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "output_type" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Assign" + input_arg { + name: "ref" + type_attr: "T" + is_ref: true + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output_ref" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + } + attr { + name: "validate_shape" + type: "bool" + default_value { + b: true + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: true + } + } + allows_uninitialized_input: true + } + op { + name: "BroadcastGradientArgs" + input_arg { + name: "s0" + type_attr: "T" + } + input_arg { + name: "s1" + type_attr: "T" + } + output_arg { + name: "r0" + type_attr: "T" + } + output_arg { + name: "r1" + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Cast" + input_arg { + name: "x" + type_attr: "SrcT" + } + output_arg { + name: "y" + type_attr: "DstT" + } + attr { + name: "SrcT" + type: "type" + } + attr { + name: "DstT" + type: "type" + } + } + op { + name: "ConcatV2" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + input_arg { + name: "axis" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 2 + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Const" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "value" + type: "tensor" + } + attr { + name: "dtype" + type: "type" + } + } + op { + name: "Equal" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type: DT_BOOL + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_QUINT8 + type: DT_QINT8 + type: DT_QINT32 + type: DT_STRING + type: DT_BOOL + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "ExpandDims" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "dim" + type_attr: "Tdim" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tdim" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Fill" + input_arg { + name: "dims" + type: DT_INT32 + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "FloorDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Identity" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "MatMul" + input_arg { + name: "a" + type_attr: "T" + } + input_arg { + name: "b" + type_attr: "T" + } + output_arg { + name: "product" + type_attr: "T" + } + attr { + name: "transpose_a" + type: "bool" + default_value { + b: false + } + } + attr { + name: "transpose_b" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Maximum" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + is_commutative: true + } + op { + name: "Mean" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "MergeV2Checkpoints" + input_arg { + name: "checkpoint_prefixes" + type: DT_STRING + } + input_arg { + name: "destination_prefix" + type: DT_STRING + } + attr { + name: "delete_old_dirs" + type: "bool" + default_value { + b: true + } + } + is_stateful: true + } + op { + name: "Mul" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "NoOp" + } + op { + name: "Pack" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + } + attr { + name: "axis" + type: "int" + default_value { + i: 0 + } + } + } + op { + name: "Placeholder" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + } + op { + name: "Prod" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RealDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Reshape" + input_arg { + name: "tensor" + type_attr: "T" + } + input_arg { + name: "shape" + type_attr: "Tshape" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tshape" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RestoreV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + output_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "SaveV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + input_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "Shape" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "T" + type: "type" + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "ShardedFilename" + input_arg { + name: "basename" + type: DT_STRING + } + input_arg { + name: "shard" + type: DT_INT32 + } + input_arg { + name: "num_shards" + type: DT_INT32 + } + output_arg { + name: "filename" + type: DT_STRING + } + } + op { + name: "Slice" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "begin" + type_attr: "Index" + } + input_arg { + name: "size" + type_attr: "Index" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Index" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "SoftmaxCrossEntropyWithLogits" + input_arg { + name: "features" + type_attr: "T" + } + input_arg { + name: "labels" + type_attr: "T" + } + output_arg { + name: "loss" + type_attr: "T" + } + output_arg { + name: "backprop" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + } + op { + name: "StringJoin" + input_arg { + name: "inputs" + type: DT_STRING + number_attr: "N" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "separator" + type: "string" + default_value { + s: "" + } + } + } + op { + name: "Sub" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Sum" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Tile" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "multiples" + type_attr: "Tmultiples" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tmultiples" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "VariableV2" + output_arg { + name: "ref" + type_attr: "dtype" + is_ref: true + } + attr { + name: "shape" + type: "shape" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true + } + op { + name: "ZerosLike" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + } + tags: "serve" + tensorflow_version: "1.4.1" + tensorflow_git_version: "v1.4.0-19-ga52c8d9" + } + graph_def { + node { + name: "Placeholder" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + node { + name: "Placeholder_1" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + node { + name: "layer/zeros" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "layer/Variable" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "layer/Variable/Assign" + op: "Assign" + input: "layer/Variable" + input: "layer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "layer/Variable/read" + op: "Identity" + input: "layer/Variable" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "layer/zeros_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 10 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "layer/Variable_1" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "layer/Variable_1/Assign" + op: "Assign" + input: "layer/Variable_1" + input: "layer/zeros_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "layer/Variable_1/read" + op: "Identity" + input: "layer/Variable_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "layer/MatMul" + op: "MatMul" + input: "Placeholder" + input: "layer/Variable/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "layer/add" + op: "Add" + input: "layer/MatMul" + input: "layer/Variable_1/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "Rank" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape" + op: "Shape" + input: "layer/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Rank_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape_1" + op: "Shape" + input: "layer/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Sub/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub" + op: "Sub" + input: "Rank_1" + input: "Sub/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice/begin" + op: "Pack" + input: "Sub" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice/size" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "Slice" + op: "Slice" + input: "Shape_1" + input: "Slice/begin" + input: "Slice/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "concat/values_0" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node { + name: "concat/axis" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "concat" + op: "ConcatV2" + input: "concat/values_0" + input: "Slice" + input: "concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + } + node { + name: "Reshape" + op: "Reshape" + input: "layer/add" + input: "concat" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Rank_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape_2" + op: "Shape" + input: "Placeholder_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Sub_1/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub_1" + op: "Sub" + input: "Rank_2" + input: "Sub_1/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice_1/begin" + op: "Pack" + input: "Sub_1" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice_1/size" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "Slice_1" + op: "Slice" + input: "Shape_2" + input: "Slice_1/begin" + input: "Slice_1/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "concat_1/values_0" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node { + name: "concat_1/axis" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "concat_1" + op: "ConcatV2" + input: "concat_1/values_0" + input: "Slice_1" + input: "concat_1/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + } + node { + name: "Reshape_1" + op: "Reshape" + input: "Placeholder_1" + input: "concat_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "SoftmaxCrossEntropyWithLogits" + op: "SoftmaxCrossEntropyWithLogits" + input: "Reshape" + input: "Reshape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Sub_2/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub_2" + op: "Sub" + input: "Rank" + input: "Sub_2/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice_2/begin" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Slice_2/size" + op: "Pack" + input: "Sub_2" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice_2" + op: "Slice" + input: "Shape" + input: "Slice_2/begin" + input: "Slice_2/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Reshape_2" + op: "Reshape" + input: "SoftmaxCrossEntropyWithLogits" + input: "Slice_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Mean" + op: "Mean" + input: "Reshape_2" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "gradients/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } + } + node { + name: "gradients/Fill" + op: "Fill" + input: "gradients/Shape" + input: "gradients/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/Reshape/shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "gradients/Mean_grad/Reshape" + op: "Reshape" + input: "gradients/Fill" + input: "gradients/Mean_grad/Reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Shape" + op: "Shape" + input: "Reshape_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Mean_grad/Tile" + op: "Tile" + input: "gradients/Mean_grad/Reshape" + input: "gradients/Mean_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tmultiples" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Shape_1" + op: "Shape" + input: "Reshape_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Mean_grad/Shape_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "gradients/Mean_grad/Prod" + op: "Prod" + input: "gradients/Mean_grad/Shape_1" + input: "gradients/Mean_grad/Const" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Mean_grad/Const_1" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "gradients/Mean_grad/Prod_1" + op: "Prod" + input: "gradients/Mean_grad/Shape_2" + input: "gradients/Mean_grad/Const_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Mean_grad/Maximum/y" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "gradients/Mean_grad/Maximum" + op: "Maximum" + input: "gradients/Mean_grad/Prod_1" + input: "gradients/Mean_grad/Maximum/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/floordiv" + op: "FloorDiv" + input: "gradients/Mean_grad/Prod" + input: "gradients/Mean_grad/Maximum" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/Cast" + op: "Cast" + input: "gradients/Mean_grad/floordiv" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/truediv" + op: "RealDiv" + input: "gradients/Mean_grad/Tile" + input: "gradients/Mean_grad/Cast" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Reshape_2_grad/Shape" + op: "Shape" + input: "SoftmaxCrossEntropyWithLogits" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Reshape_2_grad/Reshape" + op: "Reshape" + input: "gradients/Mean_grad/truediv" + input: "gradients/Reshape_2_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/zeros_like" + op: "ZerosLike" + input: "SoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: -1 + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + op: "ExpandDims" + input: "gradients/Reshape_2_grad/Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + op: "Mul" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + input: "SoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Reshape_grad/Shape" + op: "Shape" + input: "layer/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Reshape_grad/Reshape" + op: "Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + input: "gradients/Reshape_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/layer/add_grad/Shape" + op: "Shape" + input: "layer/MatMul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/layer/add_grad/Shape_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 10 + } + } + } + } + node { + name: "gradients/layer/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/layer/add_grad/Shape" + input: "gradients/layer/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/layer/add_grad/Sum" + op: "Sum" + input: "gradients/Reshape_grad/Reshape" + input: "gradients/layer/add_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/layer/add_grad/Reshape" + op: "Reshape" + input: "gradients/layer/add_grad/Sum" + input: "gradients/layer/add_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/layer/add_grad/Sum_1" + op: "Sum" + input: "gradients/Reshape_grad/Reshape" + input: "gradients/layer/add_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/layer/add_grad/Reshape_1" + op: "Reshape" + input: "gradients/layer/add_grad/Sum_1" + input: "gradients/layer/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/layer/add_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/layer/add_grad/Reshape" + input: "^gradients/layer/add_grad/Reshape_1" + } + node { + name: "gradients/layer/add_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/layer/add_grad/Reshape" + input: "^gradients/layer/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/layer/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/layer/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/layer/add_grad/Reshape_1" + input: "^gradients/layer/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/layer/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/layer/MatMul_grad/MatMul" + op: "MatMul" + input: "gradients/layer/add_grad/tuple/control_dependency" + input: "layer/Variable/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } + } + node { + name: "gradients/layer/MatMul_grad/MatMul_1" + op: "MatMul" + input: "Placeholder" + input: "gradients/layer/add_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "gradients/layer/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/layer/MatMul_grad/MatMul" + input: "^gradients/layer/MatMul_grad/MatMul_1" + } + node { + name: "gradients/layer/MatMul_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/layer/MatMul_grad/MatMul" + input: "^gradients/layer/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/layer/MatMul_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "gradients/layer/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/layer/MatMul_grad/MatMul_1" + input: "^gradients/layer/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/layer/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "GradientDescent/learning_rate" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + } + node { + name: "GradientDescent/update_layer/Variable/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "layer/Variable" + input: "GradientDescent/learning_rate" + input: "gradients/layer/MatMul_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "GradientDescent/update_layer/Variable_1/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "layer/Variable_1" + input: "GradientDescent/learning_rate" + input: "gradients/layer/add_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "GradientDescent" + op: "NoOp" + input: "^GradientDescent/update_layer/Variable/ApplyGradientDescent" + input: "^GradientDescent/update_layer/Variable_1/ApplyGradientDescent" + } + node { + name: "init" + op: "NoOp" + input: "^layer/Variable/Assign" + input: "^layer/Variable_1/Assign" + } + node { + name: "ArgMax/dimension" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "ArgMax" + op: "ArgMax" + input: "layer/add" + input: "ArgMax/dimension" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_type" + value { + type: DT_INT64 + } + } + } + node { + name: "ArgMax_1/dimension" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "ArgMax_1" + op: "ArgMax" + input: "Placeholder_1" + input: "ArgMax_1/dimension" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_type" + value { + type: DT_INT64 + } + } + } + node { + name: "Equal" + op: "Equal" + input: "ArgMax" + input: "ArgMax_1" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Cast_1" + op: "Cast" + input: "Equal" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_BOOL + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Const_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Mean_1" + op: "Mean" + input: "Cast_1" + input: "Const_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "save/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "model" + } + } + } + } + node { + name: "save/StringJoin/inputs_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "_temp_65caff16d5244276b9828b0dab21b157/part" + } + } + } + } + node { + name: "save/StringJoin" + op: "StringJoin" + input: "save/Const" + input: "save/StringJoin/inputs_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "separator" + value { + s: "" + } + } + } + node { + name: "save/num_shards" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "save/ShardedFilename/shard" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "save/ShardedFilename" + op: "ShardedFilename" + input: "save/StringJoin" + input: "save/ShardedFilename/shard" + input: "save/num_shards" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/SaveV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "layer/Variable" + string_val: "layer/Variable_1" + } + } + } + } + node { + name: "save/SaveV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "" + string_val: "" + } + } + } + } + node { + name: "save/SaveV2" + op: "SaveV2" + input: "save/ShardedFilename" + input: "save/SaveV2/tensor_names" + input: "save/SaveV2/shape_and_slices" + input: "layer/Variable" + input: "layer/Variable_1" + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + } + node { + name: "save/control_dependency" + op: "Identity" + input: "save/ShardedFilename" + input: "^save/SaveV2" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@save/ShardedFilename" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/MergeV2Checkpoints/checkpoint_prefixes" + op: "Pack" + input: "save/ShardedFilename" + input: "^save/control_dependency" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "save/MergeV2Checkpoints" + op: "MergeV2Checkpoints" + input: "save/MergeV2Checkpoints/checkpoint_prefixes" + input: "save/Const" + attr { + key: "delete_old_dirs" + value { + b: true + } + } + } + node { + name: "save/Identity" + op: "Identity" + input: "save/Const" + input: "^save/control_dependency" + input: "^save/MergeV2Checkpoints" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/RestoreV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "layer/Variable" + } + } + } + } + node { + name: "save/RestoreV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2/tensor_names" + input: "save/RestoreV2/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign" + op: "Assign" + input: "layer/Variable" + input: "save/RestoreV2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_1/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "layer/Variable_1" + } + } + } + } + node { + name: "save/RestoreV2_1/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_1" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_1/tensor_names" + input: "save/RestoreV2_1/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_1" + op: "Assign" + input: "layer/Variable_1" + input: "save/RestoreV2_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/restore_shard" + op: "NoOp" + input: "^save/Assign" + input: "^save/Assign_1" + } + node { + name: "save/restore_all" + op: "NoOp" + input: "^save/restore_shard" + } + versions { + producer: 24 + } + } + saver_def { + filename_tensor_name: "save/Const:0" + save_tensor_name: "save/Identity:0" + restore_op_name: "save/restore_all" + max_to_keep: 5 + sharded: true + keep_checkpoint_every_n_hours: 10000.0 + version: V2 + } + collection_def { + key: "train_op" + value { + node_list { + value: "GradientDescent" + } + } + } + collection_def { + key: "trainable_variables" + value { + bytes_list { + value: "\n\020layer/Variable:0\022\025layer/Variable/Assign\032\025layer/Variable/read:02\rlayer/zeros:0" + value: "\n\022layer/Variable_1:0\022\027layer/Variable_1/Assign\032\027layer/Variable_1/read:02\017layer/zeros_1:0" + } + } + } + collection_def { + key: "variables" + value { + bytes_list { + value: "\n\020layer/Variable:0\022\025layer/Variable/Assign\032\025layer/Variable/read:02\rlayer/zeros:0" + value: "\n\022layer/Variable_1:0\022\027layer/Variable_1/Assign\032\027layer/Variable_1/read:02\017layer/zeros_1:0" + } + } + } + signature_def { + key: "serving_default" + value { + inputs { + key: "x" + value { + name: "Placeholder:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + outputs { + key: "y" + value { + name: "layer/add:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + method_name: "tensorflow/serving/predict" + } + } +} diff --git a/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/variables/variables.data-00000-of-00001 b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/variables/variables.data-00000-of-00001 Binary files differnew file mode 100644 index 00000000000..826b0280abf --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/variables/variables.data-00000-of-00001 diff --git a/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/variables/variables.index b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/variables/variables.index Binary files differnew file mode 100644 index 00000000000..d00fc5b06ed --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving_name_collision/models/parent/mnist_softmax/variables/variables.index diff --git a/config-model/src/test/cfg/application/ml_serving_name_collision/services.xml b/config-model/src/test/cfg/application/ml_serving_name_collision/services.xml new file mode 100644 index 00000000000..42528336bc5 --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving_name_collision/services.xml @@ -0,0 +1,12 @@ +<?xml version="1.0" encoding="utf-8" ?> +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<services version="1.0"> + + <container version="1.0"> + <nodes> + <node hostalias="node1" /> + </nodes> + + </container> + +</services> diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java index c5fb4f575cf..ad2f62b7dc3 100644 --- a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java @@ -1,11 +1,19 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.config.model; import ai.vespa.models.evaluation.Model; import ai.vespa.models.evaluation.ModelsEvaluator; +import ai.vespa.models.evaluation.RankProfilesConfigImporter; +import com.yahoo.config.FileReference; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.filedistribution.fileacquirer.FileAcquirer; +import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.container.ContainerCluster; import org.junit.After; @@ -56,37 +64,67 @@ public class ModelEvaluationTest { private void assertHasMlModels(VespaModel model) { ContainerCluster cluster = model.getContainerClusters().get("container"); + RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); cluster.getConfig(b); RankProfilesConfig config = new RankProfilesConfig(b); + + RankingConstantsConfig.Builder cb = new RankingConstantsConfig.Builder(); + cluster.getConfig(cb); + RankingConstantsConfig constantsConfig = new RankingConstantsConfig(cb); + assertEquals(4, config.rankprofile().size()); Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet()); assertTrue(modelNames.contains("xgboost_2_2")); + assertTrue(modelNames.contains("mnist_saved")); assertTrue(modelNames.contains("mnist_softmax")); assertTrue(modelNames.contains("mnist_softmax_saved")); - ModelsEvaluator evaluator = new ModelsEvaluator(config); + ModelsEvaluator evaluator = new ModelsEvaluator(new ToleratingMissingConstantFilesRankProfilesConfigImporter(MockFileAcquirer.returnFile(null)) + .importFrom(config, constantsConfig)); assertEquals(4, evaluator.models().size()); + Model xgboost = evaluator.models().get("xgboost_2_2"); assertNotNull(xgboost); assertNotNull(xgboost.evaluatorOf()); assertNotNull(xgboost.evaluatorOf("xgboost_2_2")); - Model onnx = evaluator.models().get("mnist_softmax"); - assertNotNull(onnx); - assertNotNull(onnx.evaluatorOf()); - assertNotNull(onnx.evaluatorOf("default")); - assertNotNull(onnx.evaluatorOf("default", "add")); - assertNotNull(onnx.evaluatorOf("default.add")); + Model tensorflow_mnist = evaluator.models().get("mnist_saved"); + assertNotNull(tensorflow_mnist); + assertNotNull(tensorflow_mnist.evaluatorOf("serving_default")); + assertNotNull(tensorflow_mnist.evaluatorOf("serving_default", "y")); + assertNotNull(tensorflow_mnist.evaluatorOf("serving_default.y")); + assertNotNull(evaluator.evaluatorOf("mnist_saved", "serving_default.y")); + assertNotNull(evaluator.evaluatorOf("mnist_saved", "serving_default", "y")); + + Model onnx_mnist_softmax = evaluator.models().get("mnist_softmax"); + assertNotNull(onnx_mnist_softmax); + assertNotNull(onnx_mnist_softmax.evaluatorOf()); + assertNotNull(onnx_mnist_softmax.evaluatorOf("default")); + assertNotNull(onnx_mnist_softmax.evaluatorOf("default", "add")); + assertNotNull(onnx_mnist_softmax.evaluatorOf("default.add")); assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default.add")); assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default", "add")); - Model tensorflow = evaluator.models().get("mnist_softmax_saved"); - assertNotNull(tensorflow); - assertNotNull(tensorflow.evaluatorOf()); - assertNotNull(tensorflow.evaluatorOf("serving_default")); - assertNotNull(tensorflow.evaluatorOf("serving_default", "y")); + Model tensorflow_mnist_softmax = evaluator.models().get("mnist_softmax_saved"); + assertNotNull(tensorflow_mnist_softmax); + assertNotNull(tensorflow_mnist_softmax.evaluatorOf()); + assertNotNull(tensorflow_mnist_softmax.evaluatorOf("serving_default")); + assertNotNull(tensorflow_mnist_softmax.evaluatorOf("serving_default", "y")); + } + + // We don't have function file distribution so just return empty tensor constants + private static class ToleratingMissingConstantFilesRankProfilesConfigImporter extends RankProfilesConfigImporter { + + public ToleratingMissingConstantFilesRankProfilesConfigImporter(FileAcquirer fileAcquirer) { + super(fileAcquirer); + } + + protected Tensor readTensorFromFile(String name, TensorType type, FileReference fileReference) { + return Tensor.from(type, "{}"); + } + } } diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelNameCollisionTest.java b/config-model/src/test/java/com/yahoo/config/model/ModelNameCollisionTest.java new file mode 100644 index 00000000000..08f18331d1c --- /dev/null +++ b/config-model/src/test/java/com/yahoo/config/model/ModelNameCollisionTest.java @@ -0,0 +1,43 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.config.model; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.vespa.model.VespaModel; +import org.junit.After; +import org.junit.Test; +import org.xml.sax.SAXException; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class ModelNameCollisionTest { + + private static final Path appDir = Path.fromString("src/test/cfg/application/ml_serving_name_collision"); + + @After + public void removeGeneratedModelFiles() { + IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + } + + @Test + public void testMl_ServingApplication() throws SAXException, IOException { + ApplicationPackageTester tester = ApplicationPackageTester.create(appDir.toString()); + try { + new VespaModel(tester.app()); + } + catch (IllegalArgumentException e) { + assertEquals("The models in " + + appDir + "/models/parent/mnist_softmax.onnx and " + + appDir + "/models/parent/mnist_softmax" + + " both resolve to the model name 'parent_mnist_softmax'", + e.getMessage()); + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java index aa01070d296..056fc27f067 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java @@ -1,8 +1,4 @@ -/* - * // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - * - * - */ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; import org.junit.Test; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java index e1ddd0c02ca..b13ffabda77 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java @@ -7,12 +7,14 @@ import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.yolean.Exceptions; import org.junit.Test; import java.util.Optional; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; /** * @author bratseth diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java new file mode 100644 index 00000000000..df9a40d29e2 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java @@ -0,0 +1,197 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition; + +import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.yolean.Exceptions; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** + * @author bratseth + */ +public class RankingExpressionLoopDetectionTestCase { + + @Test + public void testSelfLoop() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString( + "search test {\n" + + " document test { \n" + + " field a type string { \n" + + " indexing: index \n" + + " }\n" + + " }\n" + + " \n" + + " rank-profile test {\n" + + " first-phase {\n" + + " expression: foo\n" + + " }\n" + + " macro foo() {\n" + + " expression: foo\n" + + " }\n" + + " }\n" + + "\n" + + "}\n"); + try { + builder.build(); + fail("Excepted exception"); + } + catch (IllegalArgumentException e) { + assertEquals("In search definition 'test', rank profile 'test': The first-phase expression is invalid: Invocation loop: foo -> foo", + Exceptions.toMessageString(e)); + } + } + + @Test + public void testNestedLoop() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString( + "search test {\n" + + " document test { \n" + + " field a type string { \n" + + " indexing: index \n" + + " }\n" + + " }\n" + + " \n" + + " rank-profile test {\n" + + " first-phase {\n" + + " expression: foo\n" + + " }\n" + + " macro foo() {\n" + + " expression: arg(5)\n" + + " }\n" + + " macro arg(a1) {\n" + + " expression: foo + a1*2\n" + + " }\n" + + " }\n" + + "\n" + + "}\n"); + try { + builder.build(); + fail("Excepted exception"); + } + catch (IllegalArgumentException e) { + assertEquals("In search definition 'test', rank profile 'test': The first-phase expression is invalid: Invocation loop: foo -> arg(5) -> foo", + Exceptions.toMessageString(e)); + } + } + + @Test + public void testSelfArgumentLoop() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString( + "search test {\n" + + " document test { \n" + + " field a type string { \n" + + " indexing: index \n" + + " }\n" + + " }\n" + + " \n" + + " rank-profile test {\n" + + " first-phase {\n" + + " expression: foo\n" + + " }\n" + + " macro foo() {\n" + + " expression: arg(foo)\n" + + " }\n" + + " macro arg(a1) {\n" + + " expression: a1*2\n" + + " }\n" + + " }\n" + + "\n" + + "}\n"); + try { + builder.build(); + fail("Excepted exception"); + } + catch (IllegalArgumentException e) { + assertEquals("In search definition 'test', rank profile 'test': The first-phase expression is invalid: Invocation loop: foo -> arg(foo) -> foo", + Exceptions.toMessageString(e)); + } + } + + @Test + public void testNoLoopWithSameLocalArgument() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString( + "search test {\n" + + " document test { \n" + + " field a type string { \n" + + " indexing: index \n" + + " }\n" + + " }\n" + + " \n" + + " rank-profile test {\n" + + " first-phase {\n" + + " expression: foo(3)\n" + + " }\n" + + " macro foo(a1) {\n" + + " expression: bar(3)\n" + + " }\n" + + " macro bar(a1) {\n" + + " expression: a1*2\n" + + " }\n" + + " }\n" + + "\n" + + "}\n"); + builder.build(); + } + + @Test + public void testNoLoopWithMultipleInvocations() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString( + "search test {\n" + + " document test { \n" + + " field a type string { \n" + + " indexing: index \n" + + " }\n" + + " }\n" + + " \n" + + " rank-profile test {\n" + + " first-phase {\n" + + " expression: foo(3)\n" + + " }\n" + + " macro foo(a1) {\n" + + " expression: bar(3) + bar(a1)\n" + + " }\n" + + " macro bar(a1) {\n" + + " expression: a1*2\n" + + " }\n" + + " }\n" + + "\n" + + "}\n"); + builder.build(); + } + + @Test + public void testNoLoopWithBoundIdentifiers() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString( + "search test {\n" + + " document test { \n" + + " }\n" + + " rank-profile test {\n" + + " first-phase {\n" + + " expression: foo(bar(2))\n" + + " }\n" + + " macro foo(x) {\n" + + " expression: x * x\n" + + " }\n" + + " macro bar(x) {\n" + + " expression: x + x\n" + + " }\n" + + " }\n" + + "}\n"); + builder.build(); + } + +} diff --git a/config-model/src/test/java/com/yahoo/vespa/model/admin/DedicatedAdminV4Test.java b/config-model/src/test/java/com/yahoo/vespa/model/admin/DedicatedAdminV4Test.java index 7b586354394..c1edbec6bf5 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/admin/DedicatedAdminV4Test.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/admin/DedicatedAdminV4Test.java @@ -5,10 +5,15 @@ import com.yahoo.cloud.config.LogforwarderConfig; import com.yahoo.cloud.config.SentinelConfig; import com.yahoo.config.model.NullConfigModelRegistry; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.deploy.DeployProperties; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.provision.Hosts; import com.yahoo.config.model.provision.InMemoryProvisioner; import com.yahoo.config.model.test.MockApplicationPackage; +import com.yahoo.config.provision.Environment; +import com.yahoo.config.provision.RegionName; +import com.yahoo.config.provision.SystemName; +import com.yahoo.config.provision.Zone; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.admin.monitoring.Metric; import com.yahoo.vespa.model.admin.monitoring.MetricsConsumer; @@ -183,6 +188,26 @@ public class DedicatedAdminV4Test { } } + @Test + public void testDedicatedLogserverInHostedVespa() throws IOException, SAXException { + String services = "<services>" + + " <admin version='4.0'>" + + " <logservers>" + + " <nodes count='1' dedicated='true'/>" + + " </logservers>" + + " </admin>" + + "</services>"; + + VespaModel model = createModel(hosts, services, new DeployState.Builder() + .zone(new Zone(SystemName.cd, Environment.dev, RegionName.defaultName())) + .properties(new DeployProperties.Builder() + .hostedVespa(true) + .build())); + assertEquals(1, model.getHosts().size()); + // Should create a container on the same node as logserver + assertHostContainsServices(model, "hosts/myhost0", "slobrok", "logd", "logserver", "container"); + } + private Set<String> serviceNames(VespaModel model, String hostname) { SentinelConfig config = model.getConfig(SentinelConfig.class, hostname); return config.service().stream().map(SentinelConfig.Service::name).collect(Collectors.toSet()); @@ -197,14 +222,18 @@ public class DedicatedAdminV4Test { } private VespaModel createModel(String hosts, String services) throws IOException, SAXException { + return createModel(hosts, services, new DeployState.Builder()); + } + + private VespaModel createModel(String hosts, String services, DeployState.Builder deployStateBuilder) throws IOException, SAXException { ApplicationPackage app = new MockApplicationPackage.Builder() .withHosts(hosts) .withServices(services) .build(); - return new VespaModel(new NullConfigModelRegistry(), - new DeployState.Builder().applicationPackage(app).modelHostProvisioner( - new InMemoryProvisioner(Hosts.readFrom(app.getHosts()), true)) - .build()); + return new VespaModel(new NullConfigModelRegistry(), deployStateBuilder + .applicationPackage(app) + .modelHostProvisioner(new InMemoryProvisioner(Hosts.readFrom(app.getHosts()), true)) + .build()); } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/jersey/xml/MultipleRestApisTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/jersey/xml/MultipleRestApisTest.java deleted file mode 100644 index d36ab74c6f1..00000000000 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/jersey/xml/MultipleRestApisTest.java +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.model.container.jersey.xml; - -import com.yahoo.component.ComponentId; -import com.yahoo.config.model.builder.xml.test.DomBuilderTest; -import com.yahoo.container.ComponentsConfig; -import com.yahoo.container.di.config.JerseyBundlesConfig; -import com.yahoo.container.jdisc.JdiscBindingsConfig; -import com.yahoo.vespa.model.container.jersey.JerseyHandler; -import com.yahoo.vespa.model.container.jersey.RestApi; -import com.yahoo.vespa.model.container.jersey.RestApiContext; -import com.yahoo.vespa.model.container.xml.ContainerModelBuilderTestBase; -import org.junit.Before; -import org.junit.Test; - -import java.util.Map; - -import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.CoreMatchers.hasItems; -import static org.hamcrest.CoreMatchers.not; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.nullValue; -import static org.hamcrest.core.Is.is; -import static org.junit.Assert.assertTrue; - -/** - * @author bjorncs - */ -public class MultipleRestApisTest extends ContainerModelBuilderTestBase { - - private static final String CLUSTER_ID = "container"; - private static final String PATH_1 = "rest_1"; - private static final String PATH_2 = "rest_2"; - private static final String HTTP_BINDING_1 = "http://*/" + PATH_1 + "/*"; - private static final String HTTPS_BINDING_1 = "https://*/" + PATH_1 + "/*"; - private static final String HTTP_BINDING_2 = "http://*/" + PATH_2 + "/*"; - private static final String HTTPS_BINDING_2 = "https://*/" + PATH_2 + "/*"; - private static final String HANDLER_ID_1 = JerseyHandler.CLASS + "-" + PATH_1; - private static final String HANDLER_ID_2 = JerseyHandler.CLASS + "-" + PATH_2; - private static final String REST_API_CONTEXT_ID_1 = RestApiContext.CONTAINER_CLASS + "-" + PATH_1; - private static final String REST_API_CONTEXT_ID_2 = RestApiContext.CONTAINER_CLASS + "-" + PATH_2; - private static final String REST_API_XML = - "<container version=\"1.0\" id=\"" + CLUSTER_ID + "\">\n" + - " <rest-api path=\"" + PATH_1 + "\">\n" + - " <components bundle=\"bundle1\" />\n" + - " </rest-api>\n" + - " <rest-api path=\"" + PATH_2 + "\">\n" + - " <components bundle=\"bundle2\" />\n" + - " </rest-api>\n" + - "</container>"; - - - private JerseyHandler handler1; - private JerseyHandler handler2; - private Map<ComponentId, RestApi> restApis; - - @Before - public void setup() throws Exception { - createModel(root, DomBuilderTest.parse(REST_API_XML)); - handler1 = (JerseyHandler)getContainerComponentNested(CLUSTER_ID, HANDLER_ID_1); - handler2 = (JerseyHandler)getContainerComponentNested(CLUSTER_ID, HANDLER_ID_2); - restApis = getContainerCluster(CLUSTER_ID).getRestApiMap(); - } - - @Test - public void cluster_has_all_rest_apis() { - assertThat(restApis.size(), is(2)); - } - - @Test - public void rest_apis_have_path_as_component_id() { - assertTrue(restApis.get(ComponentId.fromString(PATH_1)) instanceof RestApi); - assertTrue(restApis.get(ComponentId.fromString(PATH_2)) instanceof RestApi); - } - - @Test - public void jersey_handler_has_correct_bindings() { - assertThat(handler1, not(nullValue())); - assertThat(handler1.getServerBindings(), hasItems(HTTP_BINDING_1, HTTPS_BINDING_1)); - - assertThat(handler2, not(nullValue())); - assertThat(handler2.getServerBindings(), hasItems(HTTP_BINDING_2, HTTPS_BINDING_2)); - } - - @Test - public void jersey_bindings_are_included_in_config() { - JdiscBindingsConfig config = root.getConfig(JdiscBindingsConfig.class, CLUSTER_ID); - assertThat(config.handlers(HANDLER_ID_1).serverBindings(), hasItems(HTTP_BINDING_1, HTTPS_BINDING_1)); - assertThat(config.handlers(HANDLER_ID_2).serverBindings(), hasItems(HTTP_BINDING_2, HTTPS_BINDING_2)); - } - - - @Test - public void jersey_handler_for_each_rest_api_is_included_in_components_config() { - ComponentsConfig config = root.getConfig(ComponentsConfig.class, CLUSTER_ID); - assertThat(config.toString(), containsString(".id \"" + HANDLER_ID_1 + "\"")); - assertThat(config.toString(), containsString(".id \"" + HANDLER_ID_2 + "\"")); - } - - @Test - public void jersey_bundles_component_for_each_rest_api_is_included_in_components_config() { - - ComponentsConfig config = root.getConfig(ComponentsConfig.class, CLUSTER_ID); - assertThat(config.toString(), containsString(".id \"" + REST_API_CONTEXT_ID_1 + "\"")); - assertThat(config.toString(), containsString(".id \"" + REST_API_CONTEXT_ID_2 + "\"")); - } - - @Test - public void each_rest_api_has_correct_bundle() { - RestApiContext restApiContext1 = restApis.get(ComponentId.fromString(PATH_1)).getContext(); - RestApiContext restApiContext2 = restApis.get(ComponentId.fromString(PATH_2)).getContext(); - - JerseyBundlesConfig bundlesConfig1 = root.getConfig(JerseyBundlesConfig.class, restApiContext1.getConfigId()); - assertThat(bundlesConfig1.toString(), containsString("bundle1")); - assertThat(bundlesConfig1.toString(), not(containsString("bundle2"))); - - JerseyBundlesConfig bundlesConfig2 = root.getConfig(JerseyBundlesConfig.class, restApiContext2.getConfigId()); - assertThat(bundlesConfig2.toString(), containsString("bundle2")); - assertThat(bundlesConfig2.toString(), not(containsString("bundle1"))); - } -} diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/jersey/xml/RestApiTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/jersey/xml/RestApiTest.java index 4b28dfa0b9d..503b38b79b4 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/jersey/xml/RestApiTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/jersey/xml/RestApiTest.java @@ -2,28 +2,24 @@ package com.yahoo.vespa.model.container.jersey.xml; import com.yahoo.component.ComponentId; -import com.yahoo.config.model.builder.xml.test.DomBuilderTest; +import com.yahoo.config.model.test.TestUtil; import com.yahoo.container.ComponentsConfig; -import com.yahoo.container.config.jersey.JerseyInitConfig; import com.yahoo.container.di.config.JerseyBundlesConfig; -import com.yahoo.container.di.config.JerseyInjectionConfig; -import com.yahoo.container.jdisc.JdiscBindingsConfig; +import com.yahoo.jdisc.http.ServletPathsConfig; import com.yahoo.vespa.model.container.component.Component; -import com.yahoo.vespa.model.container.component.Handler; -import com.yahoo.vespa.model.container.jersey.JerseyHandler; +import com.yahoo.vespa.model.container.jersey.Jersey2Servlet; import com.yahoo.vespa.model.container.jersey.RestApi; import com.yahoo.vespa.model.container.jersey.RestApiContext; import com.yahoo.vespa.model.container.xml.ContainerModelBuilderTestBase; -import org.junit.Ignore; +import org.junit.Before; import org.junit.Test; +import org.w3c.dom.Element; import java.util.HashSet; import java.util.Set; import java.util.stream.Collectors; import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.CoreMatchers.hasItem; -import static org.hamcrest.CoreMatchers.hasItems; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.nullValue; @@ -32,103 +28,86 @@ import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; /** + * @author gjoranv * @author bjorncs */ public class RestApiTest extends ContainerModelBuilderTestBase { - private static final String Path = "rest/api"; - private static final String HttpBinding = "http://*/" + Path + "/*"; - private static final String HttpsBinding = "https://*/" + Path + "/*"; - private static final String HandlerId = JerseyHandler.CLASS + "-" + RestApi.idFromPath(Path); - private static final String RestApiContextId = RestApiContext.CONTAINER_CLASS + "-" + RestApi.idFromPath(Path); - private static final String InjectedComponentId = "injectedHandler"; - - private static final String ClusterId = "container"; - - private static final String restApiXml = - "<container version=\"1.0\" id=\"" + ClusterId + "\" jetty=\"true\">\n" + - " <rest-api path=\"" + Path + "\">\n" + - " <components bundle=\"my-jersey-bundle:1.0\">\n" + - " <package>com.yahoo.foo</package>\n" + - " </components>\n" + - " </rest-api>\n" + - " <handler id=\"" + InjectedComponentId + "\" />\n" + - "</container>"; + private static final String PATH = "rest/api"; + private static final String REST_API_CONTEXT_ID = RestApiContext.CONTAINER_CLASS + "-" + RestApi.idFromPath(PATH); + private static final String INJECTED_COMPONENT_ID = "injectedHandler"; + private static final String CLUSTER_ID = "container"; + + private static final Element restApiXml = TestUtil.parse( + "<container version=\"1.0\" id=\"" + CLUSTER_ID + "\">", + " <rest-api path=\"" + PATH + "\">", + " <components bundle=\"my-jersey-bundle:1.0\">", + " <package>com.yahoo.foo</package>", + " </components>", + " </rest-api>", + " <handler id=\"" + INJECTED_COMPONENT_ID + "\" />", + "</container>"); private RestApi restApi; - private JerseyHandler handler; + private Jersey2Servlet servlet; private RestApiContext context; + @Before public void setup() throws Exception { - createModel(root, DomBuilderTest.parse(restApiXml)); + createModel(root, restApiXml); root.validate(); - getContainerCluster(ClusterId).prepare(); - restApi = getContainerCluster(ClusterId).getRestApiMap().values().iterator().next(); - handler = (JerseyHandler) getContainerComponentNested(ClusterId, HandlerId); + getContainerCluster(CLUSTER_ID).prepare(); + restApi = getContainerCluster(CLUSTER_ID).getRestApiMap().values().iterator().next(); + servlet = restApi.getJersey2Servlet(); context = restApi.getContext(); } @Test - public void jersey_handler_has_correct_bindings() throws Exception { - setup(); - assertThat(handler, not(nullValue())); - assertThat(handler.getServerBindings(), hasItems(HttpBinding, HttpsBinding)); + public void jersey2_servlet_has_correct_binding_path() { + assertThat(servlet, not(nullValue())); + assertThat(servlet.bindingPath, is(PATH + "/*")); } @Test - public void jersey_bindings_are_included_in_config() throws Exception { - setup(); - JdiscBindingsConfig config = root.getConfig(JdiscBindingsConfig.class, ClusterId); - assertThat(config.handlers(HandlerId).serverBindings(), hasItems(HttpBinding, HttpsBinding)); + public void jersey2_servlet_has_correct_bundle_spec() { + assertThat(servlet.model.bundleInstantiationSpec.bundle.stringValue(), is(Jersey2Servlet.BUNDLE)); } @Test - public void jersey_handler_has_correct_bundle_spec() throws Exception { - setup(); - assertThat(handler.model.bundleInstantiationSpec.bundle.stringValue(), is(JerseyHandler.BUNDLE)); + public void rest_api_path_is_included_in_servlet_config() { + ServletPathsConfig config = root.getConfig(ServletPathsConfig.class, servlet.getConfigId()); + assertThat(config.servlets(servlet.getComponentId().stringValue()).path(), is(PATH + "/*")); } @Test - public void config_has_correct_jersey_mapping() throws Exception { - setup(); - JerseyInitConfig config = root.getConfig(JerseyInitConfig.class, handler.getConfigId()); - assertThat(config.jerseyMapping(), is(Path)); - } - - @Test - public void resource_bundles_are_included_in_config() throws Exception { - setup(); + public void resource_bundles_are_included_in_config() { JerseyBundlesConfig config = root.getConfig(JerseyBundlesConfig.class, context.getConfigId()); assertThat(config.bundles().size(), is(1)); assertThat(config.bundles(0).spec(), is("my-jersey-bundle:1.0")); } @Test - public void packages_to_scan_are_included_in_config() throws Exception { - setup(); + public void packages_to_scan_are_included_in_config() { JerseyBundlesConfig config = root.getConfig(JerseyBundlesConfig.class, context.getConfigId()); assertThat(config.bundles(0).packages(), contains("com.yahoo.foo")); } @Test - public void jersey_handler_is_included_in_components_config() throws Exception { - setup(); - ComponentsConfig config = root.getConfig(ComponentsConfig.class, ClusterId); - assertThat(config.toString(), containsString(".id \"" + HandlerId + "\"")); + public void jersey2_servlet_is_included_in_components_config() { + ComponentsConfig config = root.getConfig(ComponentsConfig.class, CLUSTER_ID); + assertThat(config.toString(), containsString(".id \"" + servlet.getComponentId().stringValue() + "\"")); } @Test - public void restApiContext_is_included_in_components_config() throws Exception { - setup(); - ComponentsConfig config = root.getConfig(ComponentsConfig.class, ClusterId); - assertThat(config.toString(), containsString(".id \"" + RestApiContextId + "\"")); + public void restApiContext_is_included_in_components_config() { + ComponentsConfig config = root.getConfig(ComponentsConfig.class, CLUSTER_ID); + assertThat(config.toString(), containsString(".id \"" + REST_API_CONTEXT_ID + "\"")); } @Test public void all_non_restApi_components_are_injected_to_RestApiContext() throws Exception { - setup(); - ComponentsConfig componentsConfig = root.getConfig(ComponentsConfig.class, ClusterId); + ComponentsConfig componentsConfig = root.getConfig(ComponentsConfig.class, CLUSTER_ID); - Set<ComponentId> clusterChildrenComponentIds = getContainerCluster(ClusterId).getAllComponents().stream() + Set<ComponentId> clusterChildrenComponentIds = getContainerCluster(CLUSTER_ID).getAllComponents().stream() .map(Component::getComponentId) .collect(Collectors.toSet()); @@ -136,7 +115,7 @@ public class RestApiTest extends ContainerModelBuilderTestBase { .map(child -> ((Component<?, ?>) child).getComponentId()) .collect(Collectors.toSet()); - //TODO: Review: replace with filtering against RestApiContext.isCycleGeneratingComponent + //TODO: try replacing with filtering against RestApiContext.isCycleGeneratingComponent ComponentId cycleInducingComponents = ComponentId.fromString("com.yahoo.container.handler.observability.ApplicationStatusHandler"); Set<ComponentId> expectedInjectedConfigIds = new HashSet<>(clusterChildrenComponentIds); @@ -165,49 +144,4 @@ public class RestApiTest extends ContainerModelBuilderTestBase { .get(); } - @Ignore // TODO: use for naming components instead - @Test - public void jdisc_components_can_be_injected() throws Exception { - setup(); - JerseyInjectionConfig config = root.getConfig(JerseyInjectionConfig.class, context.getConfigId()); - assertThat(config.inject(0).instance(), is("injectedHandler")); - assertThat(config.inject(0).forClass(), is("com.yahoo.handler.Handler")); - } - - @Ignore // TODO: use for naming a non-existent component instead - @Test(expected = IllegalArgumentException.class) - public void injecting_non_existent_component() throws Exception { - String restApiXml = - "<container version=\"1.0\" id=\"" + ClusterId + "\">\n" + - " <rest-api path=\"" + Path + "\">\n" + - " <components bundle=\"my-jersey-bundle:1.0\" />\n" + - " <inject jdisc-component=\"non-existent\" for-class=\"foo\" />\n" + - " </rest-api>\n" + - "</container>"; - createModel(root, DomBuilderTest.parse(restApiXml)); - root.validate(); - } - - @Test - public void legacy_syntax_should_produce_valid_model() throws Exception { - String legacyXml = - "<container version=\"1.0\" >\n" + - " <handler id=\"" + JerseyHandler.CLASS + "\" >\n" + - " <binding>" + HttpBinding + "</binding>\n" + - " <config name=\"jdisc.jersey.jersey-handler\">\n" + - " <jerseyMapping>jersey</jerseyMapping>\n" + - " </config>\n" + - " </handler>\n" + - "</container>"; - - createModel(root, DomBuilderTest.parse(legacyXml)); - - Handler<?> handler = (Handler<?>) getContainerComponent("container", JerseyHandler.CLASS); - assertThat(handler, not(nullValue())); - assertThat(handler.getServerBindings(), hasItem(HttpBinding)); - - JdiscBindingsConfig bindingsConfig = root.getConfig(JdiscBindingsConfig.class, ClusterId); - assertThat(bindingsConfig.handlers(JerseyHandler.CLASS).serverBindings(), hasItem(HttpBinding)); - } - } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTestBase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTestBase.java index e46e736dcd6..6a5611a7279 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTestBase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTestBase.java @@ -78,12 +78,4 @@ public abstract class ContainerModelBuilderTestBase { ComponentId.fromString(componentId)); } - // TODO: will not work with multiple instances of the same class - public Component<?, ?> getContainerComponentNested(String clusterId, String componentId) { - ComponentId id = ComponentId.fromString(componentId); - for (Component<?,?> component : getContainerCluster(clusterId).getAllComponents()) - if (id.equals(component.getComponentId())) - return component; - return null; - } } diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java b/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java index 5204da08307..6df617ea335 100644 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/Capacity.java @@ -65,12 +65,6 @@ public final class Capacity { return fromNodeCount(capacity, Optional.empty(), false, true); } - // TODO: Remove after July 2018 - @Deprecated - public static Capacity fromNodeCount(int nodeCount, Optional<String> flavor, boolean required) { - return new Capacity(nodeCount, flavor, required, true, NodeType.tenant); - } - public static Capacity fromNodeCount(int nodeCount, Optional<String> flavor, boolean required, boolean canFail) { return new Capacity(nodeCount, flavor, required, canFail, NodeType.tenant); } diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/InstanceName.java b/config-provisioning/src/main/java/com/yahoo/config/provision/InstanceName.java index 0f1b298ba83..703528e5d33 100644 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/InstanceName.java +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/InstanceName.java @@ -46,6 +46,10 @@ public class InstanceName implements Comparable<InstanceName> { return equals(InstanceName.defaultName()); } + public boolean isTester() { + return value().endsWith("-t"); + } + public String value() { return instanceName; } @Override diff --git a/config-proxy/src/main/sh/vespa-config-ctl.sh b/config-proxy/src/main/sh/vespa-config-ctl.sh index a670e69cdbf..649eef951c0 100755 --- a/config-proxy/src/main/sh/vespa-config-ctl.sh +++ b/config-proxy/src/main/sh/vespa-config-ctl.sh @@ -103,6 +103,7 @@ export LD_LIBRARY_PATH="$VESPA_HOME/lib64" case $1 in start) + nohup sbin/vespa-retention-enforcer > ${LOGDIR}/vre-start.log 2>&1 </dev/null & configsources=`bin/vespa-print-default configservers_rpc` userargs=$vespa_base__jvmargs_configproxy if [ "$userargs" == "" ]; then diff --git a/config/src/tests/failover/failover.cpp b/config/src/tests/failover/failover.cpp index 0f4a7e6bf6f..990ca761e7e 100644 --- a/config/src/tests/failover/failover.cpp +++ b/config/src/tests/failover/failover.cpp @@ -38,7 +38,7 @@ struct RPCServer : public FRT_Invokable { void init(FRT_Supervisor * s) { FRT_ReflectionBuilder rb(s); - rb.DefineMethod("config.v3.getConfig", requestTypes.c_str(), responseTypes.c_str(), true, + rb.DefineMethod("config.v3.getConfig", requestTypes.c_str(), responseTypes.c_str(), FRT_METHOD(RPCServer::getConfig), this); } diff --git a/config/src/tests/file_acquirer/file_acquirer_test.cpp b/config/src/tests/file_acquirer/file_acquirer_test.cpp index 0d2e2bf9144..0453c6ddbd0 100644 --- a/config/src/tests/file_acquirer/file_acquirer_test.cpp +++ b/config/src/tests/file_acquirer/file_acquirer_test.cpp @@ -11,7 +11,7 @@ struct ServerFixture : FRT_Invokable { vespalib::string spec; void init_rpc() { FRT_ReflectionBuilder rb(&orb); - rb.DefineMethod("waitFor", "s", "s", true, FRT_METHOD(ServerFixture::RPC_waitFor), this); + rb.DefineMethod("waitFor", "s", "s", FRT_METHOD(ServerFixture::RPC_waitFor), this); rb.MethodDesc("wait for and resolve file reference"); rb.ParamDesc("file_ref", "file reference to wait for and resolve"); rb.ReturnDesc("file_path", "actual path to the requested file"); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java index a8b4844ca43..79a8a3d8763 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java @@ -691,13 +691,14 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye } /** Returns version to use when deploying application in given environment */ - static Version decideVersion(ApplicationId application, Environment environment, Version targetVersion, boolean bootstrap) { - if (environment.isManuallyDeployed() && - !"hosted-vespa".equals(application.tenant().value()) && // Never change version of system applications - !bootstrap) { // Do not use current version when bootstrapping config server + static Version decideVersion(ApplicationId application, Environment environment, Version sessionVersion, boolean bootstrap) { + if ( environment.isManuallyDeployed() + && ! "hosted-vespa".equals(application.tenant().value()) // Never change version of system applications + && ! application.instance().isTester() // Never upgrade tester containers + && ! bootstrap) { // Do not use current version when bootstrapping config server return Vtag.currentVersion; } - return targetVersion; + return sessionVersion; } public Slime createDeployLog() { diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java index d9a653a1dc2..a5e76262f48 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java @@ -135,19 +135,25 @@ public class ApplicationRepositoryTest { public void decideVersion() { ApplicationId regularApp = ApplicationId.from("tenant1", "application1", "default"); ApplicationId systemApp = ApplicationId.from("hosted-vespa", "routing", "default"); - Version targetVersion = Version.fromString("5.0"); + ApplicationId testerApp = ApplicationId.from("tenant1", "application1", "default-t"); + Version sessionVersion = Version.fromString("5.0"); - // Always use target for system application - assertEquals(targetVersion, ApplicationRepository.decideVersion(systemApp, Environment.prod, targetVersion, false)); - assertEquals(targetVersion, ApplicationRepository.decideVersion(systemApp, Environment.dev, targetVersion, false)); - assertEquals(targetVersion, ApplicationRepository.decideVersion(systemApp, Environment.perf, targetVersion, false)); + // Always use session version for system application + assertEquals(sessionVersion, ApplicationRepository.decideVersion(systemApp, Environment.prod, sessionVersion, false)); + assertEquals(sessionVersion, ApplicationRepository.decideVersion(systemApp, Environment.dev, sessionVersion, false)); + assertEquals(sessionVersion, ApplicationRepository.decideVersion(systemApp, Environment.perf, sessionVersion, false)); + + // Always use session version for tester application + assertEquals(sessionVersion, ApplicationRepository.decideVersion(testerApp, Environment.prod, sessionVersion, false)); + assertEquals(sessionVersion, ApplicationRepository.decideVersion(testerApp, Environment.dev, sessionVersion, false)); + assertEquals(sessionVersion, ApplicationRepository.decideVersion(testerApp, Environment.perf, sessionVersion, false)); // Target for regular application depends on environment - assertEquals(targetVersion, ApplicationRepository.decideVersion(regularApp, Environment.prod, targetVersion, false)); - assertEquals(Vtag.currentVersion, ApplicationRepository.decideVersion(regularApp, Environment.dev, targetVersion, false)); + assertEquals(sessionVersion, ApplicationRepository.decideVersion(regularApp, Environment.prod, sessionVersion, false)); + assertEquals(Vtag.currentVersion, ApplicationRepository.decideVersion(regularApp, Environment.dev, sessionVersion, false)); // If bootstrap, version should be target version - assertEquals(targetVersion, ApplicationRepository.decideVersion(regularApp, Environment.dev, targetVersion, true)); - assertEquals(Vtag.currentVersion, ApplicationRepository.decideVersion(regularApp, Environment.perf, targetVersion, false)); + assertEquals(sessionVersion, ApplicationRepository.decideVersion(regularApp, Environment.dev, sessionVersion, true)); + assertEquals(Vtag.currentVersion, ApplicationRepository.decideVersion(regularApp, Environment.perf, sessionVersion, false)); } @Test diff --git a/container-accesslogging/pom.xml b/container-accesslogging/pom.xml index 0d7b134c58c..b2c9bd7db8b 100644 --- a/container-accesslogging/pom.xml +++ b/container-accesslogging/pom.xml @@ -87,17 +87,6 @@ <artifactId>bundle-plugin</artifactId> <extensions>true</extensions> </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-compiler-plugin</artifactId> - <configuration> - <compilerArgs> - <arg>-Xlint:all</arg> - <arg>-Xlint:-serial</arg> - <arg>-Werror</arg> - </compilerArgs> - </configuration> - </plugin> </plugins> <outputDirectory>${buildOutputDirectory}</outputDirectory> </build> diff --git a/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java b/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java index adadd0b1414..595bd99a759 100644 --- a/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java +++ b/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java @@ -175,7 +175,7 @@ public class JSONFormatter { duration = new BigDecimal(0xffffffff); } - return duration.setScale(3, BigDecimal.ROUND_HALF_UP); + return duration.setScale(3, RoundingMode.HALF_UP); } private static String getNormalizedURI(String rawPath, String rawQuery) { diff --git a/container-accesslogging/src/main/java/com/yahoo/container/logging/LogFileHandler.java b/container-accesslogging/src/main/java/com/yahoo/container/logging/LogFileHandler.java index c7e2a777695..d729b092670 100644 --- a/container-accesslogging/src/main/java/com/yahoo/container/logging/LogFileHandler.java +++ b/container-accesslogging/src/main/java/com/yahoo/container/logging/LogFileHandler.java @@ -2,6 +2,7 @@ package com.yahoo.container.logging; import com.yahoo.container.core.AccessLogConfig; +import com.yahoo.log.LogFileDb; import java.io.File; import java.io.FileOutputStream; @@ -250,6 +251,7 @@ public class LogFileHandler extends StreamHandler { FileOutputStream os = new FileOutputStream(fileName, true); // append mode, for safety super.setOutputStream(os); currentOutputStream = os; + if (! useSequenceNameScheme) LogFileDb.nowLoggingTo(fileName); } catch (IOException e) { throw new RuntimeException("Couldn't open log file '" + fileName + "'", e); @@ -310,7 +312,9 @@ public class LogFileHandler extends StreamHandler { if (thisN>largestN) largestN=thisN; } - file.renameTo(new File(dir,file.getName() + "." + (largestN + 1))); + File newFn = new File(dir, file.getName() + "." + (largestN + 1)); + LogFileDb.nowLoggingTo(newFn.getAbsolutePath()); + file.renameTo(newFn); } /** diff --git a/container-accesslogging/src/main/resources/configdefinitions/access-log.def b/container-accesslogging/src/main/resources/configdefinitions/access-log.def index 276128e0405..9df9299ae19 100644 --- a/container-accesslogging/src/main/resources/configdefinitions/access-log.def +++ b/container-accesslogging/src/main/resources/configdefinitions/access-log.def @@ -1,11 +1,16 @@ # Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. namespace=container.core - # File name patterns supporting the expected time variables, e.g. ".%Y%m%d%H%M%S" fileHandler.pattern string + +# When should rotation happen, in minutes after midnight +# Does this really need to be configurable? +# Could just configure "every N minutes" instead fileHandler.rotation string default="0 60 ..." +# TODO remove in Vespa 7, always use DATE +# # Defines how file rotation is done. There are two options: # # DATE: @@ -27,4 +32,5 @@ fileHandler.rotateScheme enum {DATE, SEQUENCE} default=DATE fileHandler.symlink string default="" # compress the previous access log after rotation +# TODO change to "true" for Vespa 7 fileHandler.compressOnRotation bool default=false diff --git a/container-core/src/main/java/com/yahoo/container/Container.java b/container-core/src/main/java/com/yahoo/container/Container.java index efe7c58563c..7e5ea7bd948 100755 --- a/container-core/src/main/java/com/yahoo/container/Container.java +++ b/container-core/src/main/java/com/yahoo/container/Container.java @@ -145,7 +145,7 @@ public class Container { /** * Only for internal use. */ - public void setCustomFileAcquirer(final FileAcquirer fileAcquirer) { + public void setCustomFileAcquirer(FileAcquirer fileAcquirer) { if (this.fileAcquirer != null) { throw new RuntimeException("Can't change file acquirer. Is " + this.fileAcquirer + " attempted to set to " + fileAcquirer); @@ -155,7 +155,7 @@ public class Container { setPathAcquirer(fileAcquirer); } - private static void setPathAcquirer(final FileAcquirer fileAcquirer) { + private static void setPathAcquirer(FileAcquirer fileAcquirer) { ConfigTransformer.setPathAcquirer(fileReference -> { try { return fileAcquirer.waitFor(fileReference, 15, TimeUnit.MINUTES).toPath(); diff --git a/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java b/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java index c94dc30fd6f..ee12c7d4c9f 100644 --- a/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java +++ b/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java @@ -3,12 +3,13 @@ package com.yahoo.container.core; /** * @author gjoranv - * @since 5.46 */ public interface BundleLoaderProperties { + // TODO: This should be removed. The prefix is used to separate the bundles in BundlesConfig // into those that are transferred with filedistribution and those that are preinstalled // on disk. Instead, the model should have put them in two different configs. I.e. create a new // config 'preinstalled-bundles.def'. - public static final String DISK_BUNDLE_PREFIX = "file:"; + String DISK_BUNDLE_PREFIX = "file:"; + } diff --git a/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java b/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java index eceb41f9739..557f331395b 100644 --- a/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java +++ b/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java @@ -41,8 +41,7 @@ public class BundleLoader { initialBundles = Arrays.asList(osgi.getBundles()); } - private List<Bundle> obtainBundles(FileReference reference, FileAcquirer fileAcquirer) - throws InterruptedException { + private List<Bundle> obtainBundles(FileReference reference, FileAcquirer fileAcquirer) throws InterruptedException { File file = fileAcquirer.waitFor(reference, 7, TimeUnit.DAYS); return osgi.install(file.getAbsolutePath()); } @@ -95,7 +94,7 @@ public class BundleLoader { log.info("Installing bundle from disk with reference '" + reference.value() + "'"); File file = new File(referenceFileName); - if (!file.exists()) { + if ( ! file.exists()) { throw new IllegalArgumentException("Reference '" + reference.value() + "' not found on disk."); } diff --git a/container-dependency-versions/pom.xml b/container-dependency-versions/pom.xml index ccb8a9c311c..259fcfb8de7 100644 --- a/container-dependency-versions/pom.xml +++ b/container-dependency-versions/pom.xml @@ -466,7 +466,7 @@ <guava.version>18.0</guava.version> <guice.version>3.0</guice.version> <jaxb.version>2.3.0</jaxb.version> - <jetty.version>9.4.10.v20180503</jetty.version> + <jetty.version>9.4.12.v20180830</jetty.version> <slf4j.version>1.7.5</slf4j.version> <!-- These must be kept in sync with version used by current jersey2.version. --> diff --git a/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/ComponentGraph.java b/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/ComponentGraph.java index 463de0c089a..76ca94c9286 100644 --- a/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/ComponentGraph.java +++ b/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/ComponentGraph.java @@ -77,7 +77,6 @@ public class ComponentGraph { private Optional<Node> lookupGlobalComponent(Key<?> key) { if (!(key.getTypeLiteral().getType() instanceof Class)) { - throw new RuntimeException("Type not supported " + key.getTypeLiteral()); } Class<?> clazz = key.getTypeLiteral().getRawType(); diff --git a/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/Exceptions.java b/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/Exceptions.java index d84d771fef6..e8c527aeaef 100644 --- a/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/Exceptions.java +++ b/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/Exceptions.java @@ -3,6 +3,7 @@ package com.yahoo.container.di.componentgraph.core; import java.util.Arrays; class Exceptions { + static <E extends Throwable> E removeStackTrace(E exception) { if (preserveStackTrace()) { return exception; diff --git a/container-disc/CMakeLists.txt b/container-disc/CMakeLists.txt index 1b661020166..92f5b303d41 100644 --- a/container-disc/CMakeLists.txt +++ b/container-disc/CMakeLists.txt @@ -6,7 +6,6 @@ vespa_install_script(src/main/sh/vespa-start-container-daemon.sh vespa-start-con install_config_definition(src/main/resources/configdefinitions/container.jdisc.config.http-server.def) install_config_definition(src/main/resources/configdefinitions/jdisc-bindings.def container.jdisc.jdisc-bindings.def) install_config_definition(src/main/resources/configdefinitions/jersey-connection.def container.config.jersey.jersey-connection.def) -install_config_definition(src/main/resources/configdefinitions/jersey-init.def container.config.jersey.jersey-init.def) install_config_definition(src/main/resources/configdefinitions/jersey-web-app-pool.def container.config.jersey.jersey-web-app-pool.def) install_config_definition(src/main/resources/configdefinitions/metric-defaults.def container.jdisc.config.metric-defaults.def) install_config_definition(src/main/resources/configdefinitions/score-board.def jdisc.metrics.yamasconsumer.cloud.score-board.def) diff --git a/container-disc/src/main/resources/configdefinitions/jersey-init.def b/container-disc/src/main/resources/configdefinitions/jersey-init.def deleted file mode 100644 index 95ec9f23906..00000000000 --- a/container-disc/src/main/resources/configdefinitions/jersey-init.def +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -## Do NOT move this file to the container-jersey module. If system bundles -## like config-model import packages from container-jersey, new class -## loaders for these bundles will be created after reconfig. -namespace=container.config.jersey - -# Controlled by the config framework, do not set this from services.xml! -jerseyMapping string diff --git a/container-disc/src/main/sh/vespa-start-container-daemon.sh b/container-disc/src/main/sh/vespa-start-container-daemon.sh index e6219ab0467..21c9dc28022 100755 --- a/container-disc/src/main/sh/vespa-start-container-daemon.sh +++ b/container-disc/src/main/sh/vespa-start-container-daemon.sh @@ -52,7 +52,7 @@ getconfig() { qrstartcfg="`cat ${config_dir}/qr-start.cfg`" ;; *) - qrstartcfg="`$VESPA_HOME/bin/vespa-get-config -w 10 -n search.config.qr-start -i ${VESPA_CONFIG_ID}`" + qrstartcfg="`$VESPA_HOME/bin/vespa-get-config -l -w 10 -n search.config.qr-start -i ${VESPA_CONFIG_ID}`" ;; esac cmds=`echo "$qrstartcfg" | perl -ne 's/^(\w+)\.(\w+) (.*)/$1_$2=$3/ && print'` diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java index 0f067f33b79..336efcdfbc3 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java @@ -222,9 +222,9 @@ public class FastSearcher extends VespaBackEndSearcher { */ private CloseableChannel getChannel(Query query) { if (query.properties().getBoolean(dispatchInternal, false)) { - Optional<CloseableChannel> directDispatchChannel = dispatcher.getDispatchBackend(query); - if(directDispatchChannel.isPresent()) { - return directDispatchChannel.get(); + Optional<CloseableChannel> dispatchedChannel = dispatcher.getDispatchedChannel(query); + if (dispatchedChannel.isPresent()) { + return dispatchedChannel.get(); } } if (!query.properties().getBoolean(dispatchDirect, true)) diff --git a/container-search/src/main/java/com/yahoo/prelude/query/CompositeItem.java b/container-search/src/main/java/com/yahoo/prelude/query/CompositeItem.java index 2c05f2e7edf..eee9949d831 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/CompositeItem.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/CompositeItem.java @@ -73,8 +73,7 @@ public abstract class CompositeItem extends Item { */ public void addItem(int index, Item item) { if (index > subitems.size() || index < 0) { - throw new IndexOutOfBoundsException( - "Could not add a subitem at position " + index + " to " + this); + throw new IndexOutOfBoundsException("Could not add a subitem at position " + index + " to " + this); } adding(item); subitems.add(index, item); diff --git a/container-search/src/main/java/com/yahoo/prelude/query/NonReducibleCompositeItem.java b/container-search/src/main/java/com/yahoo/prelude/query/NonReducibleCompositeItem.java index 547825cb51c..84aa177369a 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/NonReducibleCompositeItem.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/NonReducibleCompositeItem.java @@ -7,10 +7,9 @@ package com.yahoo.prelude.query; * <p> * Most composites, like AND and OR, are reducible as e.g (AND a) is semantically equal to (a). * <p> - * This type functions as a marked interfaces for query rewriters. + * This type functions as a marker type for query rewriters. * * @author bratseth - * @since 5.1.22 */ public abstract class NonReducibleCompositeItem extends CompositeItem { } diff --git a/container-search/src/main/java/com/yahoo/prelude/query/SameElementItem.java b/container-search/src/main/java/com/yahoo/prelude/query/SameElementItem.java index ca2c5a80283..aa446140da0 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/SameElementItem.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/SameElementItem.java @@ -11,6 +11,7 @@ import java.util.Iterator; * This represents a query where all terms are required to match in the same element id. * The primary usecase is to allow efficient search in arrays and maps of struct. * The common path is the field name containing the struct. + * * @author baldersheim */ @Beta diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/CloseableChannel.java b/container-search/src/main/java/com/yahoo/search/dispatch/CloseableChannel.java index 9a3e7e71031..838afa0c7fc 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/CloseableChannel.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/CloseableChannel.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.dispatch; import com.yahoo.fs4.BasicPacket; @@ -11,6 +12,9 @@ import java.io.Closeable; import java.io.IOException; import java.util.Optional; +/** + * @author ollivir + */ public class CloseableChannel implements Closeable { private FS4Channel channel; diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/DispatchedChannel.java b/container-search/src/main/java/com/yahoo/search/dispatch/DispatchedChannel.java new file mode 100644 index 00000000000..00c59fbc979 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/dispatch/DispatchedChannel.java @@ -0,0 +1,38 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.dispatch; + +import com.yahoo.prelude.fastsearch.FS4ResourcePool; +import com.yahoo.search.dispatch.SearchCluster.Group; +import com.yahoo.search.dispatch.SearchCluster.Node; + +import java.util.Optional; + +/** + * An extension to CloseableChannel that encapsulates the release of a LoadBalancer group allocation. + * + * @author ollivir + */ +public class DispatchedChannel extends CloseableChannel { + private final SearchCluster.Group group; + private final LoadBalancer loadBalancer; + private boolean groupAllocated = true; + + public DispatchedChannel(FS4ResourcePool fs4ResourcePool, LoadBalancer loadBalancer, Group group, Node node) { + super(fs4ResourcePool.getBackend(node.hostname(), node.fs4port(), Optional.of(node.key()))); + + this.loadBalancer = loadBalancer; + this.group = group; + } + + public DispatchedChannel(FS4ResourcePool fs4ResourcePool, LoadBalancer loadBalancer, Group group) { + this(fs4ResourcePool, loadBalancer, group, group.nodes().iterator().next()); + } + + public void close() { + if (groupAllocated) { + groupAllocated = false; + loadBalancer.releaseGroup(group); + } + super.close(); + } +} diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java index d0f03dde3dd..be7cfea2017 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java @@ -11,7 +11,6 @@ import com.yahoo.container.protect.Error; import com.yahoo.prelude.fastsearch.DocumentDatabase; import com.yahoo.slime.ArrayTraverser; import com.yahoo.data.access.slime.SlimeAdapter; -import com.yahoo.fs4.mplex.Backend; import com.yahoo.prelude.fastsearch.FS4ResourcePool; import com.yahoo.prelude.fastsearch.FastHit; import com.yahoo.prelude.fastsearch.TimeoutException; @@ -284,19 +283,17 @@ public class Dispatcher extends AbstractComponent { } - public Optional<CloseableChannel> getDispatchBackend(Query query) { - Optional<SearchCluster.Group> groupInCluster = loadBalancer.getGroupForQuery(query); + public Optional<CloseableChannel> getDispatchedChannel(Query query) { + Optional<SearchCluster.Group> groupInCluster = loadBalancer.takeGroupForQuery(query); return groupInCluster.flatMap(group -> { if(group.nodes().size() == 1) { - return Optional.of(group.nodes().get(0)); + query.trace(false, 2, "Dispatching directly (anywhere) to ", group); + return Optional.of(new DispatchedChannel(fs4ResourcePool, loadBalancer, group)); } else { + loadBalancer.releaseGroup(group); return Optional.empty(); } - }).map(node -> { - query.trace(false, 2, "Dispatching directly (anywhere) to ", node); - Backend backend = fs4ResourcePool.getBackend(node.hostname(), node.fs4port(), Optional.of(node.key())); - return new CloseableChannel(backend); }); } } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java b/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java index 8e90eae0eb3..d8e12980472 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java @@ -1,27 +1,138 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.dispatch; import com.yahoo.search.Query; import com.yahoo.search.dispatch.SearchCluster.Group; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import java.util.Optional; +import java.util.logging.Level; +import java.util.logging.Logger; +/** + * LoadBalancer determines which group of content nodes should be accessed next for each search query when the internal java dispatcher is + * used. + * + * @author ollivir + */ public class LoadBalancer { + // The implementation here is a simplistic least queries in flight + round-robin load balancer + // TODO: consider the options in com.yahoo.vespa.model.content.TuningDispatch - private final SearchCluster searchCluster; + private final static Logger log = Logger.getLogger(LoadBalancer.class.getName()); + + private final boolean isInternallyDispatchable; + private final List<GroupSchedule> scoreboard; + private int needle = 0; public LoadBalancer(SearchCluster searchCluster) { - this.searchCluster = searchCluster; + if (searchCluster == null) { + this.isInternallyDispatchable = false; + this.scoreboard = null; + return; + } + this.isInternallyDispatchable = (searchCluster.groupSize() == 1); + this.scoreboard = new ArrayList<>(searchCluster.groups().size()); + + for (Group group : searchCluster.groups().values()) { + scoreboard.add(new GroupSchedule(group)); + } + Collections.shuffle(scoreboard); + } + + /** + * Select and allocate the search cluster group which is to be used for the provided query. Callers <b>must</b> call + * {@link #releaseGroup(Group)} symmetrically for each taken allocation. + * + * @param query + * @return The node group to target, or <i>empty</i> if the internal dispatch logic cannot be used + */ + public Optional<Group> takeGroupForQuery(Query query) { + if (!isInternallyDispatchable) { + return Optional.empty(); + } + + return allocateNextGroup(); } - public Optional<Group> getGroupForQuery(Query query) { - if (searchCluster.groups().size() == 1) { - for(Group group: searchCluster.groups().values()) { - // since the number of groups is 1, this will run only once - if(group.nodes().size() == 1) { - return Optional.of(group); + /** + * Release an allocation given by {@link #takeGroupForQuery(Query)}. The release must be done exactly once for each allocation. + * + * @param group + * previously allocated group + */ + public void releaseGroup(Group group) { + synchronized (this) { + for (GroupSchedule sched : scoreboard) { + if (sched.group.id() == group.id()) { + sched.adjustScore(-1); + break; } } } - return Optional.empty(); + } + + private Optional<Group> allocateNextGroup() { + synchronized (this) { + GroupSchedule bestSchedule = null; + + int index = needle; + for (int i = 0; i < scoreboard.size(); i++) { + GroupSchedule sched = scoreboard.get(index); + if (sched.isPreferredOver(bestSchedule)) { + bestSchedule = sched; + } + index = nextScoreboardIndex(index); + } + needle = nextScoreboardIndex(needle); + + Group ret = null; + if (bestSchedule != null) { + bestSchedule.adjustScore(1); + ret = bestSchedule.group; + } + if (log.isLoggable(Level.FINE)) { + log.fine("Offering <" + ret + "> for query connection"); + } + return Optional.ofNullable(ret); + } + } + + private int nextScoreboardIndex(int current) { + int next = current + 1; + if (next >= scoreboard.size()) { + next %= scoreboard.size(); + } + return next; + } + + private static class GroupSchedule { + private final Group group; + private int score; + + public GroupSchedule(Group group) { + this.group = group; + this.score = 0; + } + + public boolean isPreferredOver(GroupSchedule other) { + if (! group.hasSufficientCoverage()) { + return false; + } + if (other == null) { + return true; + } + return this.score < other.score; + } + + public void adjustScore(int amount) { + this.score += amount; + if (score < 0) { + log.warning("Double free of query target group detected"); + score = 0; + } + } } } diff --git a/container-search/src/main/java/com/yahoo/search/grouping/GroupingValidator.java b/container-search/src/main/java/com/yahoo/search/grouping/GroupingValidator.java index d96f490909e..06b030dbc78 100644 --- a/container-search/src/main/java/com/yahoo/search/grouping/GroupingValidator.java +++ b/container-search/src/main/java/com/yahoo/search/grouping/GroupingValidator.java @@ -5,6 +5,7 @@ import com.google.inject.Inject; import com.yahoo.component.chain.dependencies.After; import com.yahoo.component.chain.dependencies.Before; import com.yahoo.component.chain.dependencies.Provides; +import com.yahoo.search.grouping.request.AttributeMapLookupValue; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.container.QrSearchersConfig; import com.yahoo.processing.request.CompoundName; @@ -18,8 +19,7 @@ import com.yahoo.search.grouping.request.GroupingExpression; import com.yahoo.search.searchchain.Execution; import com.yahoo.search.searchchain.PhaseNames; -import java.util.HashSet; -import java.util.Set; +import java.util.HashMap; import static com.yahoo.search.grouping.GroupingQueryParser.SELECT_PARAMETER_PARSING; @@ -37,7 +37,7 @@ public class GroupingValidator extends Searcher { public static final String GROUPING_VALIDATED = "GroupingValidated"; public static final CompoundName PARAM_ENABLED = new CompoundName("validate_" + GroupingQueryParser.PARAM_REQUEST); - private final Set<String> attributeNames = new HashSet<>(); + private final HashMap<String, AttributesConfig.Attribute> attributes = new HashMap<>(); private final String clusterName; private final boolean enabled; @@ -55,7 +55,7 @@ public class GroupingValidator extends Searcher { enabled = (indexingMode != QrSearchersConfig.Searchcluster.Indexingmode.STREAMING); clusterName = enabled ? qrsConfig.searchcluster(clusterId).name() : null; for (AttributesConfig.Attribute attr : attributesConfig.attribute()) { - attributeNames.add(attr.name()); + attributes.put(attr.name(), attr); } } @@ -69,15 +69,42 @@ public class GroupingValidator extends Searcher { return execution.search(query); } + private void verifyHasAttribute(String attributeName) { + if (!attributes.containsKey(attributeName)) { + throw new UnavailableAttributeException(clusterName, attributeName); + } + } + + private void verifyCompatibleAttributeTypes(String keyAttributeName, + String keySourceAttributeName) { + AttributesConfig.Attribute keyAttribute = attributes.get(keyAttributeName); + AttributesConfig.Attribute keySourceAttribute = attributes.get(keySourceAttributeName); + if (!keySourceAttribute.datatype().equals(keyAttribute.datatype())) { + throw new IllegalArgumentException("Grouping request references key source attribute '" + + keySourceAttributeName + "' with data type '" + keySourceAttribute.datatype() + + "' that is different than data type '" + keyAttribute.datatype() + "' of key attribute '" + + keyAttributeName + "'"); + } + if (!keySourceAttribute.collectiontype().equals(AttributesConfig.Attribute.Collectiontype.Enum.SINGLE)) { + throw new IllegalArgumentException("Grouping request references key source attribute '" + + keySourceAttributeName + "' which is not of single value type"); + } + } + private class MyVisitor implements ExpressionVisitor { @Override public void visitExpression(GroupingExpression exp) { - if (exp instanceof AttributeValue) { - String name = ((AttributeValue)exp).getAttributeName(); - if (!attributeNames.contains(name)) { - throw new UnavailableAttributeException(clusterName, name); + if (exp instanceof AttributeMapLookupValue) { + AttributeMapLookupValue mapLookup = (AttributeMapLookupValue) exp; + verifyHasAttribute(mapLookup.getKeyAttribute()); + verifyHasAttribute(mapLookup.getValueAttribute()); + if (mapLookup.hasKeySourceAttribute()) { + verifyHasAttribute(mapLookup.getKeySourceAttribute()); + verifyCompatibleAttributeTypes(mapLookup.getKeyAttribute(), mapLookup.getKeySourceAttribute()); } + } else if (exp instanceof AttributeValue) { + verifyHasAttribute(((AttributeValue) exp).getAttributeName()); } } } diff --git a/container-search/src/main/java/com/yahoo/search/grouping/request/AttributeMapLookupValue.java b/container-search/src/main/java/com/yahoo/search/grouping/request/AttributeMapLookupValue.java new file mode 100644 index 00000000000..82c4b6763d8 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/grouping/request/AttributeMapLookupValue.java @@ -0,0 +1,61 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.grouping.request; + +/** + * This class represents a lookup in a map attribute in a {@link GroupingExpression}. + * + * It evaluates to the value found using the given key for the lookup in that attribute. + * The key is either specified explicitly or found via a key source attribute. + * + * @author geirst + */ +public class AttributeMapLookupValue extends AttributeValue { + + private final String prefix; + private final String suffix; + private final String key; + private final String keySourceAttribute; + + private AttributeMapLookupValue(String attributeValue, String prefix, String suffix, String key, String keySourceAttribute) { + super(attributeValue); + this.prefix = prefix; + this.suffix = suffix; + this.key = key; + this.keySourceAttribute = keySourceAttribute; + } + + public static AttributeMapLookupValue fromKey(String prefix, String key, String suffix) { + return new AttributeMapLookupValue(prefix + "{\"" + key + "\"}" + suffix, + prefix, suffix, key, ""); + } + + public static AttributeMapLookupValue fromKeySourceAttribute(String prefix, String keySourceAttribute, String suffix) { + return new AttributeMapLookupValue(prefix + "{attribute(" + keySourceAttribute + ")}" + suffix, + prefix, suffix, "", keySourceAttribute); + } + + @Override + public AttributeMapLookupValue copy() { + return new AttributeMapLookupValue(getAttributeName(), prefix, suffix, key, keySourceAttribute); + } + + public String getKeyAttribute() { + return prefix + ".key"; + } + + public String getValueAttribute() { + return prefix + ".value" + suffix; + } + + public String getKey() { + return key; + } + + public boolean hasKeySourceAttribute() { + return !keySourceAttribute.isEmpty(); + } + + public String getKeySourceAttribute() { + return keySourceAttribute; + } +} diff --git a/container-search/src/main/java/com/yahoo/search/grouping/vespa/ExpressionConverter.java b/container-search/src/main/java/com/yahoo/search/grouping/vespa/ExpressionConverter.java index d2dfb3c0ee7..95384fb12d3 100644 --- a/container-search/src/main/java/com/yahoo/search/grouping/vespa/ExpressionConverter.java +++ b/container-search/src/main/java/com/yahoo/search/grouping/vespa/ExpressionConverter.java @@ -6,6 +6,7 @@ import com.yahoo.search.grouping.request.AggregatorNode; import com.yahoo.search.grouping.request.AndFunction; import com.yahoo.search.grouping.request.ArrayAtLookup; import com.yahoo.search.grouping.request.AttributeFunction; +import com.yahoo.search.grouping.request.AttributeMapLookupValue; import com.yahoo.search.grouping.request.AttributeValue; import com.yahoo.search.grouping.request.AvgAggregator; import com.yahoo.search.grouping.request.BucketValue; @@ -263,6 +264,9 @@ class ExpressionConverter { if (exp instanceof AndFunction) { return addArguments(new AndFunctionNode(), (AndFunction)exp); } + if (exp instanceof AttributeMapLookupValue) { + return new AttributeNode(((AttributeMapLookupValue)exp).getAttributeName()); + } if (exp instanceof AttributeValue) { return new AttributeNode(((AttributeValue)exp).getAttributeName()); } diff --git a/container-search/src/main/java/com/yahoo/search/query/rewrite/QueryRewriteSearcher.java b/container-search/src/main/java/com/yahoo/search/query/rewrite/QueryRewriteSearcher.java index 2d0ff0c62db..7b24a00cf60 100644 --- a/container-search/src/main/java/com/yahoo/search/query/rewrite/QueryRewriteSearcher.java +++ b/container-search/src/main/java/com/yahoo/search/query/rewrite/QueryRewriteSearcher.java @@ -36,14 +36,13 @@ import java.util.logging.Logger; public abstract class QueryRewriteSearcher extends Searcher { // Indicate whether rewriter is properly initiated - private boolean isOk = false; + private boolean isOk; protected final Logger logger = Logger.getLogger(QueryRewriteSearcher.class.getName()); // HashMap which store the rewriter dicts // It has the following format: - // HashMap<String(e.g. dictionary name, etc), - // Object(e.g. FSA, etc)>> + // HashMap<String(e.g. dictionary name, etc), Object(e.g. FSA, etc)>> protected HashMap<String, Object> rewriterDicts = new HashMap<>(); /** @@ -201,14 +200,14 @@ public abstract class QueryRewriteSearcher extends Searcher { "FSA file location for " + fsaName + ": " + fsaPath); // Retrieve FSA File handler - File fsaFile = null; - if(fileAcquirer!=null) { + File fsaFile; + if (fileAcquirer != null) { fsaFile = fileAcquirer.waitFor(fsaPath, 5, TimeUnit.MINUTES); - } else if(fileList!=null) { + } else { fsaFile = fileList.get(fsaName); } - if(fsaFile==null) { + if (fsaFile == null) { RewriterUtils.error(logger, "Error loading FSA dictionary file handler"); return false; } diff --git a/container-search/src/main/javacc/com/yahoo/search/grouping/request/parser/GroupingParser.jj b/container-search/src/main/javacc/com/yahoo/search/grouping/request/parser/GroupingParser.jj index 0678b030bc5..6a55a32eb8a 100644 --- a/container-search/src/main/javacc/com/yahoo/search/grouping/request/parser/GroupingParser.jj +++ b/container-search/src/main/javacc/com/yahoo/search/grouping/request/parser/GroupingParser.jj @@ -404,14 +404,31 @@ AndFunction andFunction(GroupingOperation grp) : AttributeValue attributeValue() : { - StringBuilder ret = new StringBuilder(); + StringBuilder prefix = new StringBuilder(); + StringBuilder suffix = new StringBuilder(); String str; + String key = null; + AttributeFunction keySourceAttr = null; } { - ( str = identifier() { ret.append(str); } - ( ( <DOT> { ret.append(token.image); } ( str = identifier() { ret.append(str); } ) ) | - ( lcurly() str = string() { ret.append("{\"").append(str).append("\"}"); } rcurly() ) )* ) - { return new AttributeValue(ret.toString()); } + ( str = identifier() { prefix.append(str); } + ( LOOKAHEAD(2) <DOT> { prefix.append(token.image); } ( str = identifier() { prefix.append(str); } ) )* + ( LOOKAHEAD(3) + ( lcurly() key = string() rcurly() ) | + ( lcurly() keySourceAttr = attributeFunction() rcurly() ) + )? + ( <DOT> { suffix.append(token.image); } ( str = identifier() { suffix.append(str); } ) )* + ) + { + if (key != null) { + return AttributeMapLookupValue.fromKey(prefix.toString(), key, suffix.toString()); + } else if (keySourceAttr != null) { + return AttributeMapLookupValue.fromKeySourceAttribute(prefix.toString(), keySourceAttr.getAttributeName(), suffix.toString()); + } else { + prefix.append(suffix.toString()); + return new AttributeValue(prefix.toString()); + } + } } AttributeFunction attributeFunction() : diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java index 448a8d0e894..2ba991310f5 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java @@ -4,15 +4,19 @@ package com.yahoo.search.dispatch; import com.yahoo.search.dispatch.SearchCluster.Group; import com.yahoo.search.dispatch.SearchCluster.Node; import junit.framework.AssertionFailedError; -import org.hamcrest.Matchers; import org.junit.Test; import java.util.Arrays; import java.util.Optional; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertThat; +/** + * @author ollivir + */ public class LoadBalancerTest { @Test public void requreThatLoadBalancerServesSingleNodeSetups() { @@ -20,22 +24,25 @@ public class LoadBalancerTest { SearchCluster cluster = new SearchCluster(88.0, Arrays.asList(n1), null, 1, null); LoadBalancer lb = new LoadBalancer(cluster); - Optional<Group> grp = lb.getGroupForQuery(null); + Optional<Group> grp = lb.takeGroupForQuery(null); Group group = grp.orElseGet(() -> { throw new AssertionFailedError("Expected a SearchCluster.Group"); }); - assertThat(group.nodes().size(), Matchers.equalTo(1)); + assertThat(group.nodes().size(), equalTo(1)); } @Test - public void requreThatLoadBalancerIgnoresMultiGroupSetups() { + public void requreThatLoadBalancerServesMultiGroupSetups() { Node n1 = new SearchCluster.Node(0, "test-node1", 0, 0); Node n2 = new SearchCluster.Node(1, "test-node2", 1, 1); SearchCluster cluster = new SearchCluster(88.0, Arrays.asList(n1, n2), null, 1, null); LoadBalancer lb = new LoadBalancer(cluster); - Optional<Group> grp = lb.getGroupForQuery(null); - assertThat(grp.isPresent(), is(false)); + Optional<Group> grp = lb.takeGroupForQuery(null); + Group group = grp.orElseGet(() -> { + throw new AssertionFailedError("Expected a SearchCluster.Group"); + }); + assertThat(group.nodes().size(), equalTo(1)); } @Test @@ -45,7 +52,7 @@ public class LoadBalancerTest { SearchCluster cluster = new SearchCluster(88.0, Arrays.asList(n1, n2), null, 2, null); LoadBalancer lb = new LoadBalancer(cluster); - Optional<Group> grp = lb.getGroupForQuery(null); + Optional<Group> grp = lb.takeGroupForQuery(null); assertThat(grp.isPresent(), is(false)); } @@ -58,7 +65,53 @@ public class LoadBalancerTest { SearchCluster cluster = new SearchCluster(88.0, Arrays.asList(n1, n2, n3, n4), null, 2, null); LoadBalancer lb = new LoadBalancer(cluster); - Optional<Group> grp = lb.getGroupForQuery(null); + Optional<Group> grp = lb.takeGroupForQuery(null); assertThat(grp.isPresent(), is(false)); } + + @Test + public void requreThatLoadBalancerReturnsDifferentGroups() { + Node n1 = new SearchCluster.Node(0, "test-node1", 0, 0); + Node n2 = new SearchCluster.Node(1, "test-node2", 1, 1); + SearchCluster cluster = new SearchCluster(88.0, Arrays.asList(n1, n2), null, 1, null); + LoadBalancer lb = new LoadBalancer(cluster); + + // get first group + Optional<Group> grp = lb.takeGroupForQuery(null); + Group group = grp.get(); + int id1 = group.id(); + // release allocation + lb.releaseGroup(group); + + // get second group + grp = lb.takeGroupForQuery(null); + group = grp.get(); + assertThat(group.id(), not(equalTo(id1))); + } + + @Test + public void requreThatLoadBalancerReturnsGroupWithShortestQueue() { + Node n1 = new SearchCluster.Node(0, "test-node1", 0, 0); + Node n2 = new SearchCluster.Node(1, "test-node2", 1, 1); + SearchCluster cluster = new SearchCluster(88.0, Arrays.asList(n1, n2), null, 1, null); + LoadBalancer lb = new LoadBalancer(cluster); + + // get first group + Optional<Group> grp = lb.takeGroupForQuery(null); + Group group = grp.get(); + int id1 = group.id(); + + // get second group + grp = lb.takeGroupForQuery(null); + group = grp.get(); + int id2 = group.id(); + assertThat(id2, not(equalTo(id1))); + // release second allocation + lb.releaseGroup(group); + + // get third group + grp = lb.takeGroupForQuery(null); + group = grp.get(); + assertThat(group.id(), equalTo(id2)); + } } diff --git a/container-search/src/test/java/com/yahoo/search/grouping/GroupingValidatorTestCase.java b/container-search/src/test/java/com/yahoo/search/grouping/GroupingValidatorTestCase.java index 82c05c1d995..9723f96af27 100644 --- a/container-search/src/test/java/com/yahoo/search/grouping/GroupingValidatorTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/grouping/GroupingValidatorTestCase.java @@ -7,51 +7,164 @@ import com.yahoo.search.Query; import com.yahoo.search.config.ClusterConfig; import com.yahoo.search.grouping.request.GroupingOperation; import com.yahoo.search.searchchain.Execution; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import java.util.Arrays; import java.util.Collection; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; - /** * @author Simon Thoresen Hult */ public class GroupingValidatorTestCase { + @Rule + public ExpectedException thrown = ExpectedException.none(); + @Test public void requireThatAvailableAttributesDoNotThrow() { - Query query = new Query(); - GroupingRequest req = GroupingRequest.newInstance(query); - req.setRootOperation(GroupingOperation.fromString("all(group(foo) each(output(max(bar))))")); - validateGrouping("myCluster", Arrays.asList("foo", "bar"), query); + validateGrouping(Arrays.asList("foo", "bar"), + "all(group(foo) each(output(max(bar))))");; } @Test public void requireThatUnavailableAttributesThrow() { - Query query = new Query(); - GroupingRequest req = GroupingRequest.newInstance(query); - req.setRootOperation(GroupingOperation.fromString("all(group(foo) each(output(max(bar))))")); - try { - validateGrouping("myCluster", Arrays.asList("foo"), query); - fail("Validator should throw exception because attribute 'bar' is unavailable."); - } catch (UnavailableAttributeException e) { - assertEquals("myCluster", e.getClusterName()); - assertEquals("bar", e.getAttributeName()); - } + thrown.expect(UnavailableAttributeException.class); + thrown.expectMessage(createMessage("bar")); + validateGrouping(Arrays.asList("foo"), + "all(group(foo) each(output(max(bar))))"); } @Test public void requireThatEnableFlagPreventsThrow() { + Query query = createQuery("all(group(foo) each(output(max(bar))))"); + query.properties().set(GroupingValidator.PARAM_ENABLED, "false"); + validateGrouping(Arrays.asList("foo"), query); + } + + @Test + public void available_primitive_map_attribute_does_not_throw() { + validateGrouping(Arrays.asList("map.key", "map.value"), + "all(group(map{\"foo\"}) each(output(count())))"); + } + + @Test + public void unavailable_primitive_map_key_attribute_throws() { + thrown.expect(UnavailableAttributeException.class); + thrown.expectMessage(createMessage("map.key")); + validateGrouping(Arrays.asList("null"), + "all(group(map{\"foo\"}) each(output(count())))"); + } + + @Test + public void unavailable_primitive_map_value_attribute_throws() { + thrown.expect(UnavailableAttributeException.class); + thrown.expectMessage(createMessage("map.value")); + validateGrouping(Arrays.asList("map.key"), + "all(group(map{\"foo\"}) each(output(count())))"); + } + + @Test + public void available_struct_map_attribute_does_not_throw() { + validateGrouping(Arrays.asList("map.key", "map.value.name"), + "all(group(map{\"foo\"}.name) each(output(count())))"); + } + + @Test + public void unavailable_struct_map_key_attribute_throws() { + thrown.expect(UnavailableAttributeException.class); + thrown.expectMessage(createMessage("map.key")); + validateGrouping(Arrays.asList("null"), + "all(group(map{\"foo\"}.name) each(output(count())))"); + } + + @Test + public void unavailable_struct_map_value_attribute_throws() { + thrown.expect(UnavailableAttributeException.class); + thrown.expectMessage(createMessage("map.value.name")); + validateGrouping(Arrays.asList("map.key"), + "all(group(map{\"foo\"}.name) each(output(count())))"); + } + + @Test + public void available_key_source_attribute_does_not_throw() { + validateGrouping(Arrays.asList("map.key", "map.value", "key_source"), + "all(group(map{attribute(key_source)}) each(output(count())))"); + } + + @Test + public void unavailable_key_source_attribute_throws() { + thrown.expect(UnavailableAttributeException.class); + thrown.expectMessage(createMessage("key_source")); + validateGrouping(Arrays.asList("map.key", "map.value"), + "all(group(map{attribute(key_source)}) each(output(count())))"); + } + + @Test + public void key_source_attribute_with_mismatching_data_type_throws() { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Grouping request references key source attribute 'key_source' with data type 'INT32' " + + "that is different than data type 'STRING' of key attribute 'map.key'"); + + validateGrouping(setupMismatchingKeySourceAttribute(false), + "all(group(map{attribute(key_source)}) each(output(count())))"); + } + + @Test + public void key_source_attribute_with_multi_value_collection_type_throws() { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Grouping request references key source attribute 'key_source' which is not of single value type"); + + validateGrouping(setupMismatchingKeySourceAttribute(true), + "all(group(map{attribute(key_source)}) each(output(count())))"); + } + + private static AttributesConfig setupMismatchingKeySourceAttribute(boolean matchingDataType) { + AttributesConfig.Builder builder = new AttributesConfig.Builder(); + builder.attribute(new AttributesConfig.Attribute.Builder().name("map.key") + .datatype(AttributesConfig.Attribute.Datatype.Enum.STRING)); + builder.attribute(new AttributesConfig.Attribute.Builder().name("map.value")); + builder.attribute(new AttributesConfig.Attribute.Builder().name("key_source") + .datatype(matchingDataType ? AttributesConfig.Attribute.Datatype.Enum.STRING : + AttributesConfig.Attribute.Datatype.Enum.INT32) + .collectiontype(AttributesConfig.Attribute.Collectiontype.Enum.ARRAY)); + return new AttributesConfig(builder); + } + + private static String createMessage(String attributeName) { + return "Grouping request references attribute '" + attributeName + "' which is not available in cluster 'myCluster'."; + } + + private static Query createQuery(String groupingExpression) { Query query = new Query(); GroupingRequest req = GroupingRequest.newInstance(query); - req.setRootOperation(GroupingOperation.fromString("all(group(foo) each(output(max(bar))))")); - query.properties().set(GroupingValidator.PARAM_ENABLED, "false"); - validateGrouping("myCluster", Arrays.asList("foo"), query); + req.setRootOperation(GroupingOperation.fromString(groupingExpression)); + return query; + } + + private static AttributesConfig createAttributesConfig(Collection<String> attributeNames) { + AttributesConfig.Builder builder = new AttributesConfig.Builder(); + for (String attributeName : attributeNames) { + builder.attribute(new AttributesConfig.Attribute.Builder() + .name(attributeName)); + } + return new AttributesConfig(builder); + } + + private static void validateGrouping(Collection<String> attributeNames, String groupingExpression) { + validateGrouping("myCluster", createAttributesConfig(attributeNames), createQuery(groupingExpression)); + } + + private static void validateGrouping(AttributesConfig attributesCfg, String groupingExpression) { + validateGrouping("myCluster", attributesCfg, createQuery(groupingExpression)); } - private static void validateGrouping(String clusterName, Collection<String> attributeNames, Query query) { + private static void validateGrouping(Collection<String> attributeNames, Query query) { + validateGrouping("myCluster", createAttributesConfig(attributeNames), query); + } + + private static void validateGrouping(String clusterName, AttributesConfig attributesConfig, Query query) { QrSearchersConfig.Builder qrsConfig = new QrSearchersConfig.Builder().searchcluster( new QrSearchersConfig.Searchcluster.Builder() .indexingmode(QrSearchersConfig.Searchcluster.Indexingmode.Enum.REALTIME) @@ -59,15 +172,10 @@ public class GroupingValidatorTestCase { ClusterConfig.Builder clusterConfig = new ClusterConfig.Builder(). clusterId(0). clusterName("test"); - AttributesConfig.Builder attributesConfig = new AttributesConfig.Builder(); - for (String attributeName : attributeNames) { - attributesConfig.attribute(new AttributesConfig.Attribute.Builder() - .name(attributeName)); - } new Execution( new GroupingValidator(new QrSearchersConfig(qrsConfig), - new ClusterConfig(clusterConfig), - new AttributesConfig(attributesConfig)), + new ClusterConfig(clusterConfig), + attributesConfig), Execution.Context.createContextStub()).search(query); } } diff --git a/container-search/src/test/java/com/yahoo/search/grouping/request/parser/GroupingParserTestCase.java b/container-search/src/test/java/com/yahoo/search/grouping/request/parser/GroupingParserTestCase.java index 2c43873036e..afbad73f982 100644 --- a/container-search/src/test/java/com/yahoo/search/grouping/request/parser/GroupingParserTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/grouping/request/parser/GroupingParserTestCase.java @@ -2,6 +2,7 @@ package com.yahoo.search.grouping.request.parser; import com.yahoo.search.grouping.request.AllOperation; +import com.yahoo.search.grouping.request.AttributeMapLookupValue; import com.yahoo.search.grouping.request.EachOperation; import com.yahoo.search.grouping.request.GroupingOperation; import com.yahoo.search.query.parser.Parsable; @@ -24,12 +25,6 @@ import static org.junit.Assert.fail; */ public class GroupingParserTestCase { - // -------------------------------------------------------------------------------- - // - // Tests. - // - // -------------------------------------------------------------------------------- - @Test public void requireThatMathAllowsWhitespace() { for (String op : Arrays.asList("+", " +", " + ", "+ ", @@ -448,6 +443,46 @@ public class GroupingParserTestCase { assertParse("all(group(my.little{key }))", "all(group(my.little{\"key\"}))"); assertParse("all(group(my.little{\"key\"}))", "all(group(my.little{\"key\"}))"); assertParse("all(group(my.little{\"key{}%\"}))", "all(group(my.little{\"key{}%\"}))"); + assertParse("all(group(my.little{key}.name))", "all(group(my.little{\"key\"}.name))"); + assertParse("all(group(my.little{key }.name))", "all(group(my.little{\"key\"}.name))"); + assertParse("all(group(my.little{\"key\"}.name))", "all(group(my.little{\"key\"}.name))"); + assertParse("all(group(my.little{\"key{}%\"}.name))", "all(group(my.little{\"key{}%\"}.name))"); + + assertAttributeMapLookup("all(group(my_map{\"my_key\"}))", + "my_map.key", "my_map.value", "my_key", ""); + assertAttributeMapLookup("all(group(my_map{\"my_key\"}.name))", + "my_map.key", "my_map.value.name", "my_key", ""); + assertAttributeMapLookup("all(group(my.map{\"my_key\"}))", + "my.map.key", "my.map.value", "my_key", ""); + } + + @Test + public void testMapSyntaxWithKeySourceAttribute() { + assertAttributeMapLookup("all(group(my_map{attribute(my_attr)}))", + "my_map.key", "my_map.value", "", "my_attr"); + assertAttributeMapLookup("all(group(my_map{attribute(my_attr)}.name))", + "my_map.key", "my_map.value.name", "", "my_attr"); + assertAttributeMapLookup("all(group(my.map{attribute(my_attr.name)}))", + "my.map.key", "my.map.value", "", "my_attr.name"); + + assertIllegalArgument("all(group(my_map{attribute(\"my_attr\")}))", + "Encountered \" <STRING> \"\\\"my_attr\\\" \"\" at line 1, column 28"); + + } + + private static void assertAttributeMapLookup(String request, + String expKeyAttribute, + String expValueAttribute, + String expKey, + String expKeySourceAttribute) { + assertParse(request, request); + List<GroupingOperation> operations = GroupingOperation.fromStringAsList(request); + assertEquals(1, operations.size()); + AttributeMapLookupValue mapLookup = (AttributeMapLookupValue)operations.get(0).getGroupBy(); + assertEquals(expKeyAttribute, mapLookup.getKeyAttribute()); + assertEquals(expValueAttribute, mapLookup.getValueAttribute()); + assertEquals(expKey, mapLookup.getKey()); + assertEquals(expKeySourceAttribute, mapLookup.getKeySourceAttribute()); } @Test diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ServiceConvergence.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ServiceConvergence.java index 8a90224083b..6cfdc9fadc8 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ServiceConvergence.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ServiceConvergence.java @@ -1,35 +1,63 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.api.integration.configserver; +import com.google.common.collect.ImmutableList; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.HostName; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; +import java.util.List; +import java.util.OptionalLong; + /** * Service convergence status for an application. * * @author mpolden + * @author jonmv */ public class ServiceConvergence { private final ApplicationId application; private final ZoneId zone; private final boolean converged; + private final long wantedGeneration; + private final List<Status> services; - public ServiceConvergence(ApplicationId application, ZoneId zone, boolean converged) { + public ServiceConvergence(ApplicationId application, ZoneId zone, boolean converged, + long wantedGeneration, List<Status> services) { this.application = application; this.zone = zone; this.converged = converged; + this.wantedGeneration = wantedGeneration; + this.services = ImmutableList.copyOf(services); } - public ApplicationId application() { - return application; - } + public ApplicationId application() { return application; } + public ZoneId zone() { return zone; } + public boolean converged() { return converged; } + public long wantedGeneration() { return wantedGeneration; } + public List<Status> services() { return services; } - public ZoneId zone() { - return zone; - } - public boolean converged() { - return converged; + /** Immutable class detailing the config status of a particular service for an application. */ + public static class Status { + private final HostName host; + private final long port; + private final String type; + private final long currentGeneration; + + public Status(HostName host, long port, String type, long currentGeneration) { + this.host = host; + this.port = port; + this.type = type; + this.currentGeneration = currentGeneration; + } + + public HostName host() { return host; } + public long port() { return port; } + public String type() { return type; } + public long currentGeneration() { return currentGeneration; } + } + } diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/MockOrganization.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/MockOrganization.java index 96ee9ecd052..8efbde52d4a 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/MockOrganization.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/MockOrganization.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.hosted.controller.api.integration.organization; import com.google.inject.Inject; +import com.yahoo.component.AbstractComponent; import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; import java.net.URI; @@ -18,7 +19,7 @@ import java.util.concurrent.atomic.AtomicLong; /** * @author jvenstad */ -public class MockOrganization implements Organization { +public class MockOrganization extends AbstractComponent implements Organization { private final Clock clock; private final AtomicLong counter = new AtomicLong(); @@ -89,45 +90,58 @@ public class MockOrganization implements Organization { @Override public URI issueCreationUri(PropertyId propertyId) { - return URI.create("www.issues.tld/" + propertyId.id()); + return properties.getOrDefault(propertyId, new PropertyInfo()).issueUrl; } @Override public URI contactsUri(PropertyId propertyId) { - return URI.create("www.contacts.tld/" + propertyId.id()); + return properties.getOrDefault(propertyId, new PropertyInfo()).contactsUrl; } @Override public URI propertyUri(PropertyId propertyId) { - return URI.create("www.properties.tld/" + propertyId.id()); + return properties.getOrDefault(propertyId, new PropertyInfo()).propertyUrl; } public Map<IssueId, MockIssue> issues() { return Collections.unmodifiableMap(issues); } - public void close(IssueId issueId) { + public MockOrganization close(IssueId issueId) { issues.get(issueId).open = false; touch(issueId); + return this; } - public void setDefaultAssigneeFor(PropertyId propertyId, User defaultAssignee) { - properties.get(propertyId).defaultAssignee = defaultAssignee; + public MockOrganization setContactsFor(PropertyId propertyId, List<List<User>> contacts) { + properties.get(propertyId).contacts = contacts; + return this; } - public void setContactsFor(PropertyId propertyId, List<List<User>> contacts) { - properties.get(propertyId).contacts = contacts; + public MockOrganization setPropertyUrl(PropertyId propertyId, URI url) { + properties.get(propertyId).propertyUrl = url; + return this; + } + + public MockOrganization setContactsUrl(PropertyId propertyId, URI url) { + properties.get(propertyId).contactsUrl = url; + return this; } - public void addProperty(PropertyId propertyId) { + public MockOrganization setIssueUrl(PropertyId propertyId, URI url) { + properties.get(propertyId).issueUrl = url; + return this; + } + + public MockOrganization addProperty(PropertyId propertyId) { properties.put(propertyId, new PropertyInfo()); + return this; } private void touch(IssueId issueId) { issues.get(issueId).updated = clock.instant(); } - public class MockIssue { private Issue issue; @@ -148,11 +162,13 @@ public class MockOrganization implements Organization { } - private class PropertyInfo { private User defaultAssignee; private List<List<User>> contacts = Collections.emptyList(); + private URI issueUrl; + private URI contactsUrl; + private URI propertyUrl; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java index e984edca7db..a3a4e99c38d 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java @@ -27,6 +27,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.configserver.NoInstance import com.yahoo.vespa.hosted.controller.api.integration.configserver.PrepareResponse; import com.yahoo.vespa.hosted.controller.api.integration.deployment.ApplicationStore; import com.yahoo.vespa.hosted.controller.api.integration.deployment.ArtifactRepository; +import com.yahoo.vespa.hosted.controller.api.integration.deployment.JobType; import com.yahoo.vespa.hosted.controller.api.integration.dns.NameService; import com.yahoo.vespa.hosted.controller.api.integration.dns.Record; import com.yahoo.vespa.hosted.controller.api.integration.dns.RecordData; @@ -38,7 +39,6 @@ import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; import com.yahoo.vespa.hosted.controller.application.ApplicationVersion; import com.yahoo.vespa.hosted.controller.application.Deployment; -import com.yahoo.vespa.hosted.controller.api.integration.deployment.JobType; import com.yahoo.vespa.hosted.controller.application.JobList; import com.yahoo.vespa.hosted.controller.application.JobStatus; import com.yahoo.vespa.hosted.controller.application.JobStatus.JobRun; @@ -57,6 +57,8 @@ import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; import java.time.Clock; +import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -115,9 +117,13 @@ public class ApplicationController { this.rotationRepository = new RotationRepository(rotationsConfig, this, curator); this.deploymentTrigger = new DeploymentTrigger(controller, buildService, clock); + Instant start = clock.instant(); + int count = 0; for (Application application : curator.readApplications()) { lockIfPresent(application.id(), this::store); + count++; } + log.log(Level.INFO, String.format("Wrote %d applications in %s", count, Duration.between(start, clock.instant()))); } /** Returns the application with the given id, or null if it is not present */ @@ -240,7 +246,9 @@ public class ApplicationController { */ public Application createApplication(ApplicationId id, Optional<NToken> token) { if ( ! (id.instance().isDefault())) // TODO: Support instances properly - throw new UnsupportedOperationException("Only the instance name 'default' is supported at the moment"); + throw new IllegalArgumentException("Only the instance name 'default' is supported at the moment"); + if (id.instance().isTester()) + throw new IllegalArgumentException("'" + id + "' is a tester application!"); try (Lock lock = lock(id)) { // Validate only application names which do not already exist. if (asList(id.tenant()).stream().noneMatch(application -> application.id().application().equals(id.application()))) @@ -270,9 +278,13 @@ public class ApplicationController { /** Deploys an application. If the application does not exist it is created. */ // TODO: Get rid of the options arg + // TODO jvenstad: Split this, and choose between deployDirectly and deploy in handler, excluding internally built from the latter. public ActivateResult deploy(ApplicationId applicationId, ZoneId zone, Optional<ApplicationPackage> applicationPackageFromDeployer, DeployOptions options) { + if (applicationId.instance().isTester()) + throw new IllegalArgumentException("'" + applicationId + "' is a tester application!"); + try (Lock lock = lock(applicationId)) { LockedApplication application = get(applicationId) .map(app -> new LockedApplication(app, lock)) @@ -375,7 +387,7 @@ public class ApplicationController { /** Assembles and deploys a tester application to the given zone. */ public ActivateResult deployTester(ApplicationId tester, ApplicationPackage applicationPackage, ZoneId zone, DeployOptions options) { - if ( ! tester.instance().value().endsWith("-t")) + if ( ! tester.instance().isTester()) throw new IllegalArgumentException("'" + tester + "' is not a tester application!"); return deploy(tester, applicationPackage, zone, options, Collections.emptySet(), Collections.emptySet()); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java index ee0a6875796..794b248b27a 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java @@ -12,8 +12,8 @@ import com.yahoo.vespa.curator.Lock; import com.yahoo.vespa.hosted.controller.api.identifiers.Property; import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; import com.yahoo.vespa.hosted.controller.api.integration.BuildService; -import com.yahoo.vespa.hosted.controller.api.integration.RunDataStore; import com.yahoo.vespa.hosted.controller.api.integration.MetricsService; +import com.yahoo.vespa.hosted.controller.api.integration.RunDataStore; import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzClientFactory; import com.yahoo.vespa.hosted.controller.api.integration.chef.Chef; import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServer; @@ -79,7 +79,6 @@ public class Controller extends AbstractComponent { private final ConfigServer configServer; private final MetricsService metricsService; private final Chef chef; - private final Organization organization; private final AthenzClientFactory athenzClientFactory; /** @@ -117,7 +116,6 @@ public class Controller extends AbstractComponent { this.curator = Objects.requireNonNull(curator, "Curator cannot be null"); this.gitHub = Objects.requireNonNull(gitHub, "GitHub cannot be null"); this.entityService = Objects.requireNonNull(entityService, "EntityService cannot be null"); - this.organization = Objects.requireNonNull(organization, "Organization cannot be null"); this.globalRoutingService = Objects.requireNonNull(globalRoutingService, "GlobalRoutingService cannot be null"); this.zoneRegistry = Objects.requireNonNull(zoneRegistry, "ZoneRegistry cannot be null"); this.configServer = Objects.requireNonNull(configServer, "ConfigServer cannot be null"); @@ -136,7 +134,7 @@ public class Controller extends AbstractComponent { Objects.requireNonNull(routingGenerator, "RoutingGenerator cannot be null"), Objects.requireNonNull(buildService, "BuildService cannot be null"), clock); - tenantController = new TenantController(this, curator, athenzClientFactory); + tenantController = new TenantController(this, curator, athenzClientFactory, organization); // Record the version of this controller curator().writeControllerVersion(this.hostname(), Vtag.currentVersion); @@ -289,10 +287,6 @@ public class Controller extends AbstractComponent { return chef; } - public Organization organization() { - return organization; - } - public CuratorDb curator() { return curator; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedTenant.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedTenant.java new file mode 100644 index 00000000000..cb3f50d08c7 --- /dev/null +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedTenant.java @@ -0,0 +1,76 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller; + +import com.yahoo.config.provision.TenantName; +import com.yahoo.vespa.athenz.api.AthenzDomain; +import com.yahoo.vespa.curator.Lock; +import com.yahoo.vespa.hosted.controller.api.identifiers.Property; +import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; +import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; +import com.yahoo.vespa.hosted.controller.tenant.Contact; + +import java.util.Objects; +import java.util.Optional; +import java.util.function.Consumer; + +/** + * A tenant that has been locked for modification. Provides methods for modifying a tenant's fields. + * + * @author mpolden + */ +public class LockedTenant { + + private final Lock lock; + private final TenantName name; + private final AthenzDomain domain; + private final Property property; + private final Optional<PropertyId> propertyId; + private final Optional<Contact> contact; + + /** + * Should never be constructed directly. + * + * Use {@link TenantController#lockIfPresent(TenantName, Consumer)} or + * {@link TenantController#lockOrThrow(TenantName, Consumer)} + */ + LockedTenant(AthenzTenant tenant, Lock lock) { + this(lock, tenant.name(), tenant.domain(), tenant.property(), tenant.propertyId(), tenant.contact()); + } + + private LockedTenant(Lock lock, TenantName name, AthenzDomain domain, Property property, + Optional<PropertyId> propertyId, Optional<Contact> contact) { + this.lock = Objects.requireNonNull(lock, "lock must be non-null"); + this.name = Objects.requireNonNull(name, "name must be non-null"); + this.domain = Objects.requireNonNull(domain, "domain must be non-null"); + this.property = Objects.requireNonNull(property, "property must be non-null"); + this.propertyId = Objects.requireNonNull(propertyId, "propertyId must be non-null"); + this.contact = Objects.requireNonNull(contact, "contact must be non-null"); + } + + /** Returns a read-only copy of this */ + public AthenzTenant get() { + return new AthenzTenant(name, domain, property, propertyId, contact); + } + + public LockedTenant with(AthenzDomain domain) { + return new LockedTenant(lock, name, domain, property, propertyId, contact); + } + + public LockedTenant with(Property property) { + return new LockedTenant(lock, name, domain, property, propertyId, contact); + } + + public LockedTenant with(PropertyId propertyId) { + return new LockedTenant(lock, name, domain, property, Optional.of(propertyId), contact); + } + + public LockedTenant with(Contact contact) { + return new LockedTenant(lock, name, domain, property, propertyId, Optional.of(contact)); + } + + @Override + public String toString() { + return "tenant '" + name + "'"; + } + +} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java index 228ca01e764..20847f904aa 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java @@ -10,17 +10,24 @@ import com.yahoo.vespa.curator.Lock; import com.yahoo.vespa.hosted.controller.api.identifiers.UserId; import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzClientFactory; import com.yahoo.vespa.hosted.controller.api.integration.athenz.ZmsClient; +import com.yahoo.vespa.hosted.controller.api.integration.organization.Organization; +import com.yahoo.vespa.hosted.controller.api.integration.organization.User; import com.yahoo.vespa.hosted.controller.persistence.CuratorDb; import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; +import com.yahoo.vespa.hosted.controller.tenant.Contact; import com.yahoo.vespa.hosted.controller.tenant.Tenant; import com.yahoo.vespa.hosted.controller.tenant.UserTenant; import java.time.Duration; +import java.time.Instant; import java.util.Comparator; import java.util.HashSet; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.function.Consumer; +import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -34,19 +41,20 @@ public class TenantController { private static final Logger log = Logger.getLogger(TenantController.class.getName()); - /** The controller owning this */ private final Controller controller; - - /** For persistence */ private final CuratorDb curator; - private final AthenzClientFactory athenzClientFactory; + private final Organization organization; + + public TenantController(Controller controller, CuratorDb curator, AthenzClientFactory athenzClientFactory, Organization organization) { + this.controller = Objects.requireNonNull(controller, "controller must be non-null"); + this.curator = Objects.requireNonNull(curator, "curator must be non-null"); + this.athenzClientFactory = Objects.requireNonNull(athenzClientFactory, "athenzClientFactory must be non-null"); + this.organization = Objects.requireNonNull(organization, "organization must be non-null"); - public TenantController(Controller controller, CuratorDb curator, AthenzClientFactory athenzClientFactory) { - this.controller = controller; - this.curator = curator; - this.athenzClientFactory = athenzClientFactory; // Write all tenants to ensure persisted data uses latest serialization format + Instant start = controller.clock().instant(); + int count = 0; for (Tenant tenant : curator.readTenants()) { try (Lock lock = lock(tenant.name())) { if (tenant instanceof AthenzTenant) { @@ -57,7 +65,9 @@ public class TenantController { throw new IllegalArgumentException("Unknown tenant type: " + tenant.getClass().getSimpleName()); } } + count++; } + log.log(Level.INFO, String.format("Wrote %d tenants in %s", count, Duration.between(start, controller.clock().instant()))); } /** Returns a list of all known tenants sorted by name */ @@ -79,12 +89,51 @@ public class TenantController { } } + /** Find contact information for given tenant */ + // TODO: Move this to ContactInformationMaintainer + public Optional<Contact> findContact(AthenzTenant tenant) { + if (!tenant.propertyId().isPresent()) { + return Optional.empty(); + } + List<List<String>> persons = organization.contactsFor(tenant.propertyId().get()) + .stream() + .map(personList -> personList.stream() + .map(User::displayName) + .collect(Collectors.toList())) + .collect(Collectors.toList()); + return Optional.of(new Contact(organization.contactsUri(tenant.propertyId().get()), + organization.propertyUri(tenant.propertyId().get()), + organization.issueCreationUri(tenant.propertyId().get()), + persons)); + } + + /** + * Lock a tenant for modification and apply action. Only valid for Athenz tenants as it's the only type that + * accepts modification. + */ + public void lockIfPresent(TenantName name, Consumer<LockedTenant> action) { + try (Lock lock = lock(name)) { + athenzTenant(name).map(tenant -> new LockedTenant(tenant, lock)).ifPresent(action); + } + } + + /** Lock a tenant for modification and apply action. Throws if the tenant does not exist */ + public void lockOrThrow(TenantName name, Consumer<LockedTenant> action) { + try (Lock lock = lock(name)) { + action.accept(new LockedTenant(requireAthenzTenant(name), lock)); + } + } + + /** Replace and store any previous version of given tenant */ + public void store(LockedTenant tenant) { + curator.writeTenant(tenant.get()); + } + /** Create an user tenant with given username */ public void create(UserTenant tenant) { try (Lock lock = lock(tenant.name())) { requireNonExistent(tenant.name()); curator.writeTenant(tenant); - log.info("Created " + tenant); } } @@ -103,7 +152,6 @@ public class TenantController { } athenzClientFactory.createZmsClientWithAuthorizedServiceToken(token).createTenant(domain); curator.writeTenant(tenant); - log.info("Created " + tenant); } } @@ -129,14 +177,29 @@ public class TenantController { return curator.readAthenzTenant(name); } - /** Update Athenz tenant */ - public void updateTenant(AthenzTenant updatedTenant, NToken token) { - try (Lock lock = lock(updatedTenant.name())) { - requireExists(updatedTenant.name()); - updateAthenzDomain(updatedTenant, token); - curator.writeTenant(updatedTenant); - log.info("Updated " + updatedTenant); - } + /** Returns Athenz tenant with name or throws if no such tenant exists */ + public AthenzTenant requireAthenzTenant(TenantName name) { + return athenzTenant(name).orElseThrow(() -> new IllegalArgumentException("Tenant '" + name + "' not found")); + } + + /** Update Athenz domain for tenant. Returns the updated tenant which must be explicitly stored */ + public LockedTenant withDomain(LockedTenant tenant, AthenzDomain newDomain, NToken token) { + AthenzDomain existingDomain = tenant.get().domain(); + if (existingDomain.equals(newDomain)) return tenant; + Optional<Tenant> existingTenantWithNewDomain = tenantIn(newDomain); + if (existingTenantWithNewDomain.isPresent()) + throw new IllegalArgumentException("Could not set domain of " + tenant + " to '" + newDomain + + "':" + existingTenantWithNewDomain.get() + " already has this domain"); + + ZmsClient zmsClient = athenzClientFactory.createZmsClientWithAuthorizedServiceToken(token); + zmsClient.createTenant(newDomain); + List<Application> applications = controller.applications().asList(tenant.get().name()); + applications.forEach(a -> zmsClient.addApplication(newDomain, new com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId(a.id().application().value()))); + applications.forEach(a -> zmsClient.deleteApplication(existingDomain, new com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId(a.id().application().value()))); + zmsClient.deleteTenant(existingDomain); + log.info("Set Athenz domain for '" + tenant + "' from '" + existingDomain + "' to '" + newDomain + "'"); + + return tenant.with(newDomain); } /** Delete an user tenant */ @@ -160,28 +223,6 @@ public class TenantController { + "': This tenant has active applications"); } curator.removeTenant(name); - log.info("Deleted " + name); - } - - private void updateAthenzDomain(AthenzTenant updatedTenant, NToken token) { - Optional<AthenzTenant> existingTenant = athenzTenant(updatedTenant.name()); - if ( ! existingTenant.isPresent()) return; - - AthenzDomain existingDomain = existingTenant.get().domain(); - AthenzDomain newDomain = updatedTenant.domain(); - if (existingDomain.equals(newDomain)) return; - Optional<Tenant> existingTenantWithNewDomain = tenantIn(newDomain); - if (existingTenantWithNewDomain.isPresent()) - throw new IllegalArgumentException("Could not set domain of " + updatedTenant + " to '" + newDomain + - "':" + existingTenantWithNewDomain.get() + " already has this domain"); - - ZmsClient zmsClient = athenzClientFactory.createZmsClientWithAuthorizedServiceToken(token); - zmsClient.createTenant(newDomain); - List<Application> applications = controller.applications().asList(existingTenant.get().name()); - applications.forEach(a -> zmsClient.addApplication(newDomain, new com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId(a.id().application().value()))); - applications.forEach(a -> zmsClient.deleteApplication(existingDomain, new com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId(a.id().application().value()))); - zmsClient.deleteTenant(existingDomain); - log.info("Updated Athens domain for " + updatedTenant + " from " + existingDomain + " to " + newDomain); } private void requireNonExistent(TenantName name) { @@ -193,12 +234,6 @@ public class TenantController { } } - private void requireExists(TenantName name) { - if (!tenant(name).isPresent()) { - throw new IllegalArgumentException("Tenant '" + name + "' does not exist"); - } - } - /** * Returns a lock which provides exclusive rights to changing this tenant. * Any operation which stores a tenant need to first acquire this lock, then read, modify diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationVersion.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationVersion.java index b3dd46a3e65..703a198be1e 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationVersion.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationVersion.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.hosted.controller.application; import java.util.Objects; import java.util.Optional; +import java.util.OptionalLong; /** * An application package version, identified by a source revision and a build number. @@ -16,21 +17,21 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> { * Used in cases where application version cannot be determined, such as manual deployments (e.g. in dev * environment) */ - public static final ApplicationVersion unknown = new ApplicationVersion(Optional.empty(), Optional.empty()); + public static final ApplicationVersion unknown = new ApplicationVersion(Optional.empty(), OptionalLong.empty()); // This never changes and is only used to create a valid semantic version number, as required by application bundles private static final String majorVersion = "1.0"; private final Optional<SourceRevision> source; - private final Optional<Long> buildNumber; + private final OptionalLong buildNumber; - private ApplicationVersion(Optional<SourceRevision> source, Optional<Long> buildNumber) { + private ApplicationVersion(Optional<SourceRevision> source, OptionalLong buildNumber) { Objects.requireNonNull(source, "source cannot be null"); Objects.requireNonNull(buildNumber, "buildNumber cannot be null"); if (source.isPresent() != buildNumber.isPresent()) { throw new IllegalArgumentException("both buildNumber and source must be set together"); } - if (buildNumber.isPresent() && buildNumber.get() <= 0) { + if (buildNumber.isPresent() && buildNumber.getAsLong() <= 0) { throw new IllegalArgumentException("buildNumber must be > 0"); } this.source = source; @@ -39,7 +40,7 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> { /** Create an application package version from a completed build */ public static ApplicationVersion from(SourceRevision source, long buildNumber) { - return new ApplicationVersion(Optional.of(source), Optional.of(buildNumber)); + return new ApplicationVersion(Optional.of(source), OptionalLong.of(buildNumber)); } /** Returns an unique identifier for this version or "unknown" if version is not known */ @@ -47,7 +48,7 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> { if (isUnknown()) { return "unknown"; } - return String.format("%s.%d-%s", majorVersion, buildNumber.get(), abbreviateCommit(source.get().commit())); + return String.format("%s.%d-%s", majorVersion, buildNumber.getAsLong(), abbreviateCommit(source.get().commit())); } /** @@ -57,7 +58,7 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> { public Optional<SourceRevision> source() { return source; } /** Returns the build number that built this version */ - public Optional<Long> buildNumber() { return buildNumber; } + public OptionalLong buildNumber() { return buildNumber; } /** Returns whether this is unknown */ public boolean isUnknown() { @@ -93,6 +94,6 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> { if ( ! buildNumber().isPresent() || ! o.buildNumber().isPresent()) return Boolean.compare(buildNumber().isPresent(), o.buildNumber.isPresent()); // Application package hash sorts first - return buildNumber().get().compareTo(o.buildNumber().get()); + return Long.compare(buildNumber().getAsLong(), o.buildNumber().getAsLong()); } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java index 0a062427a8a..a2433d223dc 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java @@ -68,12 +68,12 @@ public class Deployment { public DeploymentActivity activity() { return activity; } /** Returns information about the clusters allocated to this */ - public Map<Id, ClusterInfo> clusterInfo() { + public Map<Id, ClusterInfo> clusterInfo() { return clusterInfo; } /** Returns utilization of the clusters allocated to this */ - public Map<Id, ClusterUtilization> clusterUtils() { + public Map<Id, ClusterUtilization> clusterUtils() { return clusterUtils; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzTrustStoreConfigurator.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzTrustStoreConfigurator.java index 909104d1731..e805332429b 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzTrustStoreConfigurator.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzTrustStoreConfigurator.java @@ -4,17 +4,13 @@ package com.yahoo.vespa.hosted.controller.athenz.filter; import com.google.inject.Inject; import com.yahoo.jdisc.http.ssl.SslTrustStoreConfigurator; import com.yahoo.jdisc.http.ssl.SslTrustStoreContext; -import com.yahoo.vespa.athenz.tls.KeyStoreBuilder; -import com.yahoo.vespa.athenz.tls.KeyStoreType; +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyStoreType; import com.yahoo.vespa.hosted.controller.athenz.config.AthenzConfig; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; import java.security.KeyStore; -import java.security.KeyStoreException; -import java.security.NoSuchAlgorithmException; -import java.security.cert.CertificateException; /** * Load trust store with Athenz CA certificates @@ -27,10 +23,10 @@ public class AthenzTrustStoreConfigurator implements SslTrustStoreConfigurator { @Inject public AthenzTrustStoreConfigurator(AthenzConfig config) { - this.trustStore = createTrustStore(new File(config.athenzCaTrustStore())); + this.trustStore = createTrustStore(Paths.get(config.athenzCaTrustStore())); } - private static KeyStore createTrustStore(File trustStoreFile) { + private static KeyStore createTrustStore(Path trustStoreFile) { return KeyStoreBuilder.withType(KeyStoreType.JKS) .fromFile(trustStoreFile, "changeit".toCharArray()) .build(); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java index b49722f2f2d..de2fe58bcbf 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java @@ -238,7 +238,7 @@ public class InternalStepRunner implements StepRunner { ApplicationVersion application = setTheStage ? versions.sourceApplication().orElse(versions.targetApplication()) : versions.targetApplication(); logger.log("Checking installation of " + platform + " and " + application.id() + " ..."); - if (nodesConverged(id.application(), id.type(), platform, logger) && servicesConverged(id.application(), id.type())) { + if (nodesConverged(id.application(), id.type(), platform, logger) && servicesConverged(id.application(), id.type(), logger)) { logger.log("Installation succeeded!"); return Optional.of(running); } @@ -260,7 +260,7 @@ public class InternalStepRunner implements StepRunner { } logger.log("Checking installation of tester container ..."); - if (servicesConverged(JobController.testerOf(id.application()), id.type())) { + if (servicesConverged(JobController.testerOf(id.application()), id.type(), logger)) { logger.log("Tester container successfully installed!"); return Optional.of(running); } @@ -291,11 +291,21 @@ public class InternalStepRunner implements StepRunner { && node.rebootGeneration() == node.wantedRebootGeneration()); } - private boolean servicesConverged(ApplicationId id, JobType type) { - // TODO jvenstad: Print information for each host. - return controller.configServer().serviceConvergence(new DeploymentId(id, type.zone(controller.system()))) - .map(ServiceConvergence::converged) - .orElse(false); + private boolean servicesConverged(ApplicationId id, JobType type, DualLogger logger) { + Optional<ServiceConvergence> convergence = controller.configServer().serviceConvergence(new DeploymentId(id, type.zone(controller.system()))); + if ( ! convergence.isPresent()) { + logger.log("Config status not currently available -- will retry."); + return false; + } + logger.log("Wanted config generation is " + convergence.get().wantedGeneration()); + for (ServiceConvergence.Status serviceStatus : convergence.get().services()) + if (serviceStatus.currentGeneration() != convergence.get().wantedGeneration()) + logger.log(String.format("%70s: %11s on port %4d has %s", + serviceStatus.host().value(), + serviceStatus.type(), + serviceStatus.port(), + serviceStatus.currentGeneration() == -1 ? "(unknown)" : Long.toString(serviceStatus.currentGeneration()))); + return convergence.get().converged(); } private Optional<RunStatus> startTests(RunId id, DualLogger logger) { @@ -491,7 +501,7 @@ public class InternalStepRunner implements StepRunner { " </filtering>\n" + " </http>\n" + "\n" + - " <nodes count=\"1\" flavor=\"d-2-8-50\" />\n" + + " <nodes count=\"1\" flavor=\"d-1-4-50\" />\n" + " </container>\n" + "</services>\n"; diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ContactInformationMaintainer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ContactInformationMaintainer.java new file mode 100644 index 00000000000..aaa9c09074b --- /dev/null +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ContactInformationMaintainer.java @@ -0,0 +1,44 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.maintenance; + +import com.yahoo.log.LogLevel; +import com.yahoo.vespa.hosted.controller.Controller; +import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; +import com.yahoo.vespa.hosted.controller.tenant.Tenant; +import com.yahoo.yolean.Exceptions; + +import java.time.Duration; +import java.util.logging.Logger; + +/** + * Periodically fetch and store contact information for tenants. + * + * @author mpolden + */ +public class ContactInformationMaintainer extends Maintainer { + + private static final Logger log = Logger.getLogger(ContactInformationMaintainer.class.getName()); + + public ContactInformationMaintainer(Controller controller, Duration interval, JobControl jobControl) { + super(controller, interval, jobControl); + } + + @Override + protected void maintain() { + for (Tenant t : controller().tenants().asList()) { + if (!(t instanceof AthenzTenant)) continue; // No contact information for non-Athenz tenants + AthenzTenant tenant = (AthenzTenant) t; + if (!tenant.propertyId().isPresent()) continue; // Can only update contact information if property ID is known + try { + controller().tenants().findContact(tenant).ifPresent(contact -> { + controller().tenants().lockIfPresent(t.name(), lockedTenant -> controller().tenants().store(lockedTenant.with(contact))); + }); + } catch (Exception e) { + log.log(LogLevel.WARNING, "Failed to update contact information for " + tenant + ": " + + Exceptions.toMessageString(e) + ". Retrying in " + + maintenanceInterval()); + } + } + } + +} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java index 2c65ea0e3cb..8256d9ca182 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java @@ -5,13 +5,11 @@ import com.yahoo.component.AbstractComponent; import com.yahoo.jdisc.Metric; import com.yahoo.vespa.hosted.controller.Controller; import com.yahoo.vespa.hosted.controller.api.integration.chef.Chef; -import com.yahoo.vespa.hosted.controller.api.integration.deployment.TesterCloud; import com.yahoo.vespa.hosted.controller.api.integration.dns.NameService; import com.yahoo.vespa.hosted.controller.api.integration.noderepository.NodeRepositoryClientInterface; import com.yahoo.vespa.hosted.controller.api.integration.organization.DeploymentIssues; import com.yahoo.vespa.hosted.controller.api.integration.organization.OwnershipIssues; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; -import com.yahoo.vespa.hosted.controller.deployment.InternalStepRunner; import com.yahoo.vespa.hosted.controller.maintenance.config.MaintainerConfig; import com.yahoo.vespa.hosted.controller.persistence.CuratorDb; @@ -47,6 +45,7 @@ public class ControllerMaintenance extends AbstractComponent { private final List<OsUpgrader> osUpgraders; private final OsVersionStatusUpdater osVersionStatusUpdater; private final JobRunner jobRunner; + private final ContactInformationMaintainer contactInformationMaintainer; @SuppressWarnings("unused") // instantiated by Dependency Injection public ControllerMaintenance(MaintainerConfig maintainerConfig, Controller controller, CuratorDb curator, @@ -71,6 +70,7 @@ public class ControllerMaintenance extends AbstractComponent { jobRunner = new JobRunner(controller, Duration.ofSeconds(30), jobControl); osUpgraders = osUpgraders(controller, jobControl); osVersionStatusUpdater = new OsVersionStatusUpdater(controller, maintenanceInterval, jobControl); + contactInformationMaintainer = new ContactInformationMaintainer(controller, Duration.ofHours(12), jobControl); } public Upgrader upgrader() { return upgrader; } @@ -96,6 +96,7 @@ public class ControllerMaintenance extends AbstractComponent { osUpgraders.forEach(Maintainer::deconstruct); osVersionStatusUpdater.deconstruct(); jobRunner.deconstruct(); + contactInformationMaintainer.deconstruct(); } /** Create one OS upgrader per cloud found in the zone registry of controller */ diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java index 763d26834e6..58e0b8dbeec 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java @@ -213,7 +213,7 @@ public class ApplicationSerializer { private void toSlime(ApplicationVersion applicationVersion, Cursor object) { if (applicationVersion.buildNumber().isPresent() && applicationVersion.source().isPresent()) { - object.setLong(applicationBuildNumberField, applicationVersion.buildNumber().get()); + object.setLong(applicationBuildNumberField, applicationVersion.buildNumber().getAsLong()); toSlime(applicationVersion.source().get(), object.setObject(sourceRevisionField)); } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializer.java index d55dc791462..28400b85306 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializer.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.hosted.controller.persistence; import com.yahoo.config.provision.TenantName; +import com.yahoo.slime.ArrayTraverser; import com.yahoo.slime.Cursor; import com.yahoo.slime.Inspector; import com.yahoo.slime.Slime; @@ -11,8 +12,12 @@ import com.yahoo.vespa.config.SlimeUtils; import com.yahoo.vespa.hosted.controller.api.identifiers.Property; import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; +import com.yahoo.vespa.hosted.controller.tenant.Contact; import com.yahoo.vespa.hosted.controller.tenant.UserTenant; +import java.net.URI; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; /** @@ -26,6 +31,12 @@ public class TenantSerializer { private static final String athenzDomainField = "athenzDomain"; private static final String propertyField = "property"; private static final String propertyIdField = "propertyId"; + private static final String contactField = "contact"; + private static final String contactUrlField = "contactUrl"; + private static final String propertyUrlField = "propertyUrl"; + private static final String issueTrackerUrlField = "issueTrackerUrl"; + private static final String personsField = "persons"; + private static final String personField = "person"; public Slime toSlime(AthenzTenant tenant) { Slime slime = new Slime(); @@ -34,6 +45,20 @@ public class TenantSerializer { root.setString(athenzDomainField, tenant.domain().getName()); root.setString(propertyField, tenant.property().id()); tenant.propertyId().ifPresent(propertyId -> root.setString(propertyIdField, propertyId.id())); + tenant.contact().ifPresent(contact -> { + Cursor contactObject = root.setObject(contactField); + contactObject.setString(contactUrlField, contact.url().toString()); + contactObject.setString(propertyUrlField, contact.propertyUrl().toString()); + contactObject.setString(issueTrackerUrlField, contact.issueTrackerUrl().toString()); + Cursor personsArray = contactObject.setArray(personsField); + contact.persons().forEach(personList -> { + Cursor personArray = personsArray.addArray(); + personList.forEach(person -> { + Cursor personObject = personArray.addObject(); + personObject.setString(personField, person); + }); + }); + }); return slime; } @@ -50,7 +75,8 @@ public class TenantSerializer { AthenzDomain domain = new AthenzDomain(root.field(athenzDomainField).asString()); Property property = new Property(root.field(propertyField).asString()); Optional<PropertyId> propertyId = SlimeUtils.optionalString(root.field(propertyIdField)).map(PropertyId::new); - return new AthenzTenant(name, domain, property, propertyId); + Optional<Contact> contact = contactFrom(root.field(contactField)); + return new AthenzTenant(name, domain, property, propertyId, contact); } public UserTenant userTenantFrom(Slime slime) { @@ -59,4 +85,24 @@ public class TenantSerializer { return new UserTenant(name); } + private Optional<Contact> contactFrom(Inspector object) { + if (!object.valid()) { + return Optional.empty(); + } + return Optional.of(new Contact(URI.create(object.field(contactUrlField).asString()), + URI.create(object.field(propertyUrlField).asString()), + URI.create(object.field(issueTrackerUrlField).asString()), + personsFrom(object.field(personsField)))); + } + + private List<List<String>> personsFrom(Inspector array) { + List<List<String>> personLists = new ArrayList<>(); + array.traverse((ArrayTraverser) (i, personArray) -> { + List<String> persons = new ArrayList<>(); + personArray.traverse((ArrayTraverser) (j, inspector) -> persons.add(inspector.field("person").asString())); + personLists.add(persons); + }); + return personLists; + } + } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java index 07286fda90b..22809ac18bf 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java @@ -16,6 +16,7 @@ import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.container.jdisc.LoggingRequestHandler; import com.yahoo.io.IOUtils; import com.yahoo.log.LogLevel; +import com.yahoo.restapi.Path; import com.yahoo.slime.Cursor; import com.yahoo.slime.Inspector; import com.yahoo.slime.Slime; @@ -50,7 +51,6 @@ import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServ import com.yahoo.vespa.hosted.controller.api.integration.configserver.Log; import com.yahoo.vespa.hosted.controller.api.integration.deployment.JobType; import com.yahoo.vespa.hosted.controller.api.integration.deployment.RunId; -import com.yahoo.vespa.hosted.controller.api.integration.organization.User; import com.yahoo.vespa.hosted.controller.api.integration.routing.RotationStatus; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; @@ -66,12 +66,12 @@ import com.yahoo.vespa.hosted.controller.application.JobStatus; import com.yahoo.vespa.hosted.controller.application.SourceRevision; import com.yahoo.vespa.hosted.controller.restapi.ErrorResponse; import com.yahoo.vespa.hosted.controller.restapi.MessageResponse; -import com.yahoo.restapi.Path; import com.yahoo.vespa.hosted.controller.restapi.ResourceResponse; import com.yahoo.vespa.hosted.controller.restapi.SlimeJsonResponse; import com.yahoo.vespa.hosted.controller.restapi.StringResponse; import com.yahoo.vespa.hosted.controller.restapi.filter.SetBouncerPassthruHeaderFilter; import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; +import com.yahoo.vespa.hosted.controller.tenant.Contact; import com.yahoo.vespa.hosted.controller.tenant.Tenant; import com.yahoo.vespa.hosted.controller.tenant.UserTenant; import com.yahoo.vespa.hosted.controller.versions.VespaVersion; @@ -349,6 +349,11 @@ public class ApplicationApiHandler extends LoggingRequestHandler { private void toSlime(Cursor object, Application application, HttpRequest request) { object.setString("application", application.id().application().value()); object.setString("instance", application.id().instance().value()); + object.setString("deployments", withPath("/application/v4" + + "/tenant/" + application.id().tenant().value() + + "/application/" + application.id().application().value() + + "/instance/" + application.id().instance().value() + "/job/", + request.getUri()).toString()); // Currently deploying change if (application.change().isPresent()) { @@ -642,19 +647,27 @@ public class ApplicationApiHandler extends LoggingRequestHandler { } private HttpResponse updateTenant(String tenantName, HttpRequest request) { - Optional<AthenzTenant> existingTenant = controller.tenants().athenzTenant(TenantName.from(tenantName)); - if ( ! existingTenant.isPresent()) return ErrorResponse.notFoundError("Tenant '" + tenantName + "' does not exist"); + Optional<AthenzTenant> tenant = controller.tenants().athenzTenant(TenantName.from(tenantName)); + if ( ! tenant.isPresent()) return ErrorResponse.notFoundError("Tenant '" + tenantName + "' does not exist"); Inspector requestData = toSlime(request.getData()).get(); - AthenzTenant updatedTenant = existingTenant.get() - .with(new AthenzDomain(mandatory("athensDomain", requestData).asString())) - .with(new Property(mandatory("property", requestData).asString())); - Optional<PropertyId> propertyId = optional("propertyId", requestData).map(PropertyId::new); - if (propertyId.isPresent()) { - updatedTenant = updatedTenant.with(propertyId.get()); - } - controller.tenants().updateTenant(updatedTenant, requireNToken(request, "Could not update " + tenantName)); - return tenant(updatedTenant, request, true); + NToken token = requireNToken(request, "Could not update " + tenantName); + + controller.tenants().lockOrThrow(tenant.get().name(), lockedTenant -> { + lockedTenant = lockedTenant.with(new Property(mandatory("property", requestData).asString())); + lockedTenant = controller.tenants().withDomain( + lockedTenant, + new AthenzDomain(mandatory("athensDomain", requestData).asString()), + token + ); + Optional<PropertyId> propertyId = optional("propertyId", requestData).map(PropertyId::new); + if (propertyId.isPresent()) { + lockedTenant = lockedTenant.with(propertyId.get()); + } + controller.tenants().store(lockedTenant); + }); + + return tenant(controller.tenants().requireAthenzTenant(tenant.get().name()), request, true); } private HttpResponse createTenant(String tenantName, HttpRequest request) { @@ -897,13 +910,11 @@ public class ApplicationApiHandler extends LoggingRequestHandler { private void toSlime(Cursor object, Tenant tenant, HttpRequest request, boolean listApplications) { object.setString("tenant", tenant.name().value()); object.setString("type", tentantType(tenant)); - Optional<PropertyId> propertyId = Optional.empty(); if (tenant instanceof AthenzTenant) { AthenzTenant athenzTenant = (AthenzTenant) tenant; object.setString("athensDomain", athenzTenant.domain().getName()); object.setString("property", athenzTenant.property().id()); - propertyId = athenzTenant.propertyId(); - propertyId.ifPresent(id -> object.setString("propertyId", id.toString())); + athenzTenant.propertyId().ifPresent(id -> object.setString("propertyId", id.toString())); } Cursor applicationArray = object.setArray("applications"); if (listApplications) { // This cludge is needed because we call this after deleting the tenant. As this call makes another tenant lookup it will fail. TODO is to support lookup on tenant @@ -916,23 +927,28 @@ public class ApplicationApiHandler extends LoggingRequestHandler { } } } - propertyId.ifPresent(id -> { - try { - object.setString("propertyUrl", controller.organization().propertyUri(id).toString()); - object.setString("contactsUrl", controller.organization().contactsUri(id).toString()); - object.setString("issueCreationUrl", controller.organization().issueCreationUri(id).toString()); - Cursor lists = object.setArray("contacts"); - for (List<? extends User> contactList : controller.organization().contactsFor(id)) { - Cursor list = lists.addArray(); - for (User contact : contactList) - list.addString(contact.displayName()); + if (tenant instanceof AthenzTenant) { + AthenzTenant athenzTenant = (AthenzTenant) tenant; + Optional<Contact> contact = athenzTenant.contact(); + if (!contact.isPresent()) { // TODO: Remove this fallback once all contacts have been written once + try { + contact = controller.tenants().findContact(athenzTenant); + } catch (Exception e) { + log.log(Level.WARNING, "Failed to fetch contact information for tenant " + athenzTenant + + ": " + Exceptions.toMessageString(e)); } } - catch (RuntimeException e) { - log.log(Level.WARNING, "Error fetching property info for " + tenant + " with propertyId " + id + ": " + - Exceptions.toMessageString(e)); - } - }); + contact.ifPresent(c -> { + object.setString("propertyUrl", c.propertyUrl().toString()); + object.setString("contactsUrl", c.url().toString()); + object.setString("issueCreationUrl", c.issueTrackerUrl().toString()); + Cursor contactsArray = object.setArray("contacts"); + c.persons().forEach(persons -> { + Cursor personArray = contactsArray.addArray(); + persons.forEach(personArray::addString); + }); + }); + } } // A tenant has different content when in a list ... antipattern, but not solvable before application/v5 diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelper.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelper.java index 620a3514c87..8a2664e61a7 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelper.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelper.java @@ -142,13 +142,13 @@ class JobControllerApiHandlerHelper { lastVespa = version; Version lastPlatform = lastVespa.versionNumber(); - lastPlatformObject.setString("version", lastPlatform.toString()); + lastPlatformObject.setString("platform", lastPlatform.toString()); lastPlatformObject.setLong("at", lastVespa.committedAt().toEpochMilli()); long completed = steps.productionJobs().stream().filter(type -> controller.applications().deploymentTrigger().isComplete(Change.of(lastPlatform), application, type)).count(); if (Optional.of(lastPlatform).equals(change.platform())) - lastPlatformObject.setString("deploying", completed + " of " + steps.productionJobs().size()); + lastPlatformObject.setString("deploying", completed + " of " + steps.productionJobs().size() + "complete"); else if (completed == steps.productionJobs().size()) - lastPlatformObject.setString("completed", completed + " of " + steps.productionJobs().size()); + lastPlatformObject.setString("completed", completed + " of " + steps.productionJobs().size() + "complete"); else if ( ! application.deploymentSpec().canUpgradeAt(controller.clock().instant())) { lastPlatformObject.setString("blocked", application.deploymentSpec().changeBlocker().stream() .filter(blocker -> blocker.blocksVersions()) @@ -162,7 +162,7 @@ class JobControllerApiHandlerHelper { private static void lastApplicationToSlime(Cursor lastApplicationObject, Application application, Change change, DeploymentSteps steps, Controller controller) { long completed; ApplicationVersion lastApplication = application.deploymentJobs().statusOf(component).flatMap(JobStatus::lastSuccess).get().application(); - applicationVersionToSlime(lastApplicationObject.setObject("version"), lastApplication); + applicationVersionToSlime(lastApplicationObject.setObject("application"), lastApplication); lastApplicationObject.setLong("at", application.deploymentJobs().statusOf(component).flatMap(JobStatus::lastSuccess).get().at().toEpochMilli()); completed = steps.productionJobs().stream().filter(type -> controller.applications().deploymentTrigger().isComplete(Change.of(lastApplication), application, type)).count(); if (Optional.of(lastApplication).equals(change.application())) @@ -333,11 +333,12 @@ class JobControllerApiHandlerHelper { } private static void applicationVersionToSlime(Cursor versionObject, ApplicationVersion version) { - versionObject.setString("id", version.id()); - versionObject.setLong("build", version.buildNumber().get()); - versionObject.setString("repository", version.source().get().repository()); - versionObject.setString("branch", version.source().get().branch()); - versionObject.setString("commit", version.source().get().commit()); + versionObject.setString("hash", version.id()); + versionObject.setLong("build", version.buildNumber().getAsLong()); + Cursor sourceObject = versionObject.setObject("source"); + sourceObject.setString("gitRepository", version.source().get().repository()); + sourceObject.setString("gitBranch", version.source().get().branch()); + sourceObject.setString("gitCommit", version.source().get().commit()); } /** diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/AthenzTenant.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/AthenzTenant.java index 0ba0eea2dab..8cbb4e06aca 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/AthenzTenant.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/AthenzTenant.java @@ -6,6 +6,7 @@ import com.yahoo.vespa.athenz.api.AthenzDomain; import com.yahoo.vespa.hosted.controller.api.identifiers.Property; import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; +import java.util.Objects; import java.util.Optional; /** @@ -18,16 +19,19 @@ public class AthenzTenant extends Tenant { private final AthenzDomain domain; private final Property property; private final Optional<PropertyId> propertyId; + private final Optional<Contact> contact; /** * This should only be used by serialization. * Use {@link #create(TenantName, AthenzDomain, Property, Optional)}. * */ - public AthenzTenant(TenantName name, AthenzDomain domain, Property property, Optional<PropertyId> propertyId) { + public AthenzTenant(TenantName name, AthenzDomain domain, Property property, Optional<PropertyId> propertyId, + Optional<Contact> contact) { super(name); - this.domain = domain; - this.property = property; - this.propertyId = propertyId; + this.domain = Objects.requireNonNull(domain, "domain must be non-null"); + this.property = Objects.requireNonNull(property, "property must be non-null"); + this.propertyId = Objects.requireNonNull(propertyId, "propertyId must be non-null"); + this.contact = Objects.requireNonNull(contact, "contact must be non-null"); } /** Property name of this tenant */ @@ -35,11 +39,16 @@ public class AthenzTenant extends Tenant { return property; } - /** Property ID of the tenant, if present */ + /** Property ID of the tenant, if any */ public Optional<PropertyId> propertyId() { return propertyId; } + /** Contact information for this, if any */ + public Optional<Contact> contact() { + return contact; + } + /** Athenz domain of this tenant */ public AthenzDomain domain() { return domain; @@ -55,22 +64,10 @@ public class AthenzTenant extends Tenant { return "athenz tenant '" + name() + "'"; } - public AthenzTenant with(AthenzDomain domain) { - return new AthenzTenant(name(), domain, property(), propertyId()); - } - - public AthenzTenant with(Property property) { - return new AthenzTenant(name(), domain, property, propertyId()); - } - - public AthenzTenant with(PropertyId propertyId) { - return new AthenzTenant(name(), domain, property, Optional.of(propertyId)); - } - /** Create a new Athenz tenant */ public static AthenzTenant create(TenantName name, AthenzDomain domain, Property property, Optional<PropertyId> propertyId) { - return new AthenzTenant(requireName(requireNoPrefix(name)), domain, property, propertyId); + return new AthenzTenant(requireName(requireNoPrefix(name)), domain, property, propertyId, Optional.empty()); } private static TenantName requireNoPrefix(TenantName name) { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Contact.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Contact.java new file mode 100644 index 00000000000..e13b0f982da --- /dev/null +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Contact.java @@ -0,0 +1,75 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.tenant; + +import com.google.common.collect.ImmutableList; + +import java.net.URI; +import java.util.List; +import java.util.Objects; + +/** + * Contact information for a tenant. + * + * @author mpolden + */ +public class Contact { + + private final URI url; + private final URI propertyUrl; + private final URI issueTrackerUrl; + private final List<List<String>> persons; + + public Contact(URI url, URI propertyUrl, URI issueTrackerUrl, List<List<String>> persons) { + this.propertyUrl = Objects.requireNonNull(propertyUrl, "propertyUrl must be non-null"); + this.url = Objects.requireNonNull(url, "url must be non-null"); + this.issueTrackerUrl = Objects.requireNonNull(issueTrackerUrl, "issueTrackerUrl must be non-null"); + this.persons = ImmutableList.copyOf(Objects.requireNonNull(persons, "persons must be non-null")); + } + + /** URL to this */ + public URI url() { + return url; + } + + /** URL to information about this property */ + public URI propertyUrl() { + return propertyUrl; + } + + /** URL to this contacts's issue tracker */ + public URI issueTrackerUrl() { + return issueTrackerUrl; + } + + /** Nested list of persons representing this. First level represents that person's rank in the corporate dystopia. */ + public List<List<String>> persons() { + return persons; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Contact contact = (Contact) o; + return Objects.equals(url, contact.url) && + Objects.equals(propertyUrl, contact.propertyUrl) && + Objects.equals(issueTrackerUrl, contact.issueTrackerUrl) && + Objects.equals(persons, contact.persons); + } + + @Override + public int hashCode() { + return Objects.hash(url, propertyUrl, issueTrackerUrl, persons); + } + + @Override + public String toString() { + return "Contact{" + + "url=" + url + + ", propertyUrl=" + propertyUrl + + ", issueTrackerUrl=" + issueTrackerUrl + + ", persons=" + persons + + '}'; + } + +} diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java index c067bccb4c3..367e4e52e79 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java @@ -78,6 +78,7 @@ public final class ControllerTester { private final MockBuildService buildService; private final MetricsServiceMock metricsService; private final RoutingGeneratorMock routingGenerator; + private final MockOrganization organization; private Controller controller; @@ -87,7 +88,7 @@ public final class ControllerTester { new ZoneRegistryMock(), new GitHubMock(), curatorDb, rotationsConfig, new MemoryNameService(), new ArtifactRepositoryMock(), new ApplicationStoreMock(), new MemoryEntityService(), new MockBuildService(), - metricsService, new RoutingGeneratorMock()); + metricsService, new RoutingGeneratorMock(), new MockOrganization(clock)); } public ControllerTester(ManualClock clock) { @@ -112,7 +113,8 @@ public final class ControllerTester { MemoryNameService nameService, ArtifactRepositoryMock artifactRepository, ApplicationStoreMock appStoreMock, EntityService entityService, MockBuildService buildService, - MetricsServiceMock metricsService, RoutingGeneratorMock routingGenerator) { + MetricsServiceMock metricsService, RoutingGeneratorMock routingGenerator, + MockOrganization organization) { this.athenzDb = athenzDb; this.clock = clock; this.configServer = configServer; @@ -127,9 +129,10 @@ public final class ControllerTester { this.buildService = buildService; this.metricsService = metricsService; this.routingGenerator = routingGenerator; + this.organization = organization; this.controller = createController(curator, rotationsConfig, configServer, clock, gitHub, zoneRegistry, athenzDb, nameService, artifactRepository, appStoreMock, entityService, buildService, - metricsService, routingGenerator); + metricsService, routingGenerator, organization); // Make root logger use time from manual clock configureDefaultLogHandler(handler -> handler.setFilter( @@ -175,11 +178,15 @@ public final class ControllerTester { public RoutingGeneratorMock routingGenerator() { return routingGenerator; } + public MockOrganization organization() { + return organization; + } + /** Create a new controller instance. Useful to verify that controller state is rebuilt from persistence */ public final void createNewController() { controller = createController(curator, rotationsConfig, configServer, clock, gitHub, zoneRegistry, athenzDb, nameService, artifactRepository, applicationStore, entityService, buildService, metricsService, - routingGenerator); + routingGenerator, organization); } /** Creates the given tenant and application and deploys it */ @@ -197,12 +204,6 @@ public final class ControllerTester { } /** Creates the given tenant and application and deploys it */ - public Application createAndDeploy(String tenantName, String domainName, String applicationName, - String instanceName, Environment environment, long projectId, Long propertyId) { - return createAndDeploy(tenantName, domainName, applicationName, instanceName, toZone(environment), projectId, propertyId); - } - - /** Creates the given tenant and application and deploys it */ public Application createAndDeploy(String tenantName, String domainName, String applicationName, ZoneId zone, long projectId, Long propertyId) { return createAndDeploy(tenantName, domainName, applicationName, "default", zone, projectId, propertyId); } @@ -295,12 +296,12 @@ public final class ControllerTester { ArtifactRepository artifactRepository, ApplicationStore applicationStore, EntityService entityService, BuildService buildService, MetricsServiceMock metricsService, - RoutingGenerator routingGenerator) { + RoutingGenerator routingGenerator, MockOrganization organization) { Controller controller = new Controller(curator, rotationsConfig, gitHub, entityService, - new MockOrganization(clock), + organization, new MemoryGlobalRoutingService(), zoneRegistryMock, configServer, diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java index f5fc6825960..3b381e21b27 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java @@ -375,9 +375,9 @@ public class DeploymentTriggerTest { // This component completion should remove the older outstanding change, to avoid a later downgrade. clock.advance(Duration.ofHours(1)); tester.deployAndNotify(application, applicationPackage, true, productionUsWest1); - assertEquals((Long) BuildJob.defaultBuildNumber, tester.application(application.id()).deploymentJobs().jobStatus() - .get(productionUsWest1).lastSuccess().get().application().buildNumber().get()); - assertEquals((Long) (BuildJob.defaultBuildNumber + 1), tester.application(application.id()).outstandingChange().application().get().buildNumber().get()); + assertEquals(BuildJob.defaultBuildNumber, tester.application(application.id()).deploymentJobs().jobStatus() + .get(productionUsWest1).lastSuccess().get().application().buildNumber().getAsLong()); + assertEquals((BuildJob.defaultBuildNumber + 1), tester.application(application.id()).outstandingChange().application().get().buildNumber().getAsLong()); tester.readyJobTrigger().maintain(); assertTrue(tester.buildService().jobs().isEmpty()); @@ -513,14 +513,14 @@ public class DeploymentTriggerTest { tester.assertRunning(productionUsCentral1, application.id()); assertEquals(v2, app.get().deployments().get(productionUsCentral1.zone(main)).version()); - assertEquals(Long.valueOf(42L), app.get().deployments().get(productionUsCentral1.zone(main)).applicationVersion().buildNumber().get()); + assertEquals(42, app.get().deployments().get(productionUsCentral1.zone(main)).applicationVersion().buildNumber().getAsLong()); assertNotEquals(triggered, app.get().deploymentJobs().jobStatus().get(productionUsCentral1).lastTriggered().get().at()); // Change has a higher application version than what is deployed -- deployment should trigger. tester.deployAndNotify(application, applicationPackage, false, productionUsCentral1); tester.deploy(productionUsCentral1, application, applicationPackage); assertEquals(v2, app.get().deployments().get(productionUsCentral1.zone(main)).version()); - assertEquals(Long.valueOf(43), app.get().deployments().get(productionUsCentral1.zone(main)).applicationVersion().buildNumber().get()); + assertEquals(43, app.get().deployments().get(productionUsCentral1.zone(main)).applicationVersion().buildNumber().getAsLong()); // Change is again strictly dominated, and us-central-1 is skipped, even though it is still failing. tester.clock().advance(Duration.ofHours(2).plus(Duration.ofSeconds(1))); // Enough time for retry @@ -588,8 +588,8 @@ public class DeploymentTriggerTest { tester.deployAndNotify(application, true, productionUsEast3); tester.deployAndNotify(application, true, productionEuWest1); assertFalse(app.get().change().isPresent()); - assertEquals(43, app.get().deploymentJobs().jobStatus().get(productionEuWest1).lastSuccess().get().application().buildNumber().get().longValue()); - assertEquals(43, app.get().deploymentJobs().jobStatus().get(productionUsEast3).lastSuccess().get().application().buildNumber().get().longValue()); + assertEquals(43, app.get().deploymentJobs().jobStatus().get(productionEuWest1).lastSuccess().get().application().buildNumber().getAsLong()); + assertEquals(43, app.get().deploymentJobs().jobStatus().get(productionUsEast3).lastSuccess().get().application().buildNumber().getAsLong()); } @Test diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java index dc9f3246e80..458ba49f3e3 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java @@ -110,7 +110,17 @@ public class ConfigServerMock extends AbstractComponent implements ConfigServer /** Converge all services belonging to the given application */ public void convergeServices(ApplicationId application, ZoneId zone) { - serviceStatus.put(new DeploymentId(application, zone), new ServiceConvergence(application, zone, true)); + List<Node> nodes = nodeRepository.list(zone, application); + serviceStatus.put(new DeploymentId(application, zone), new ServiceConvergence(application, + zone, + true, + 2, + nodes.stream() + .map(node -> new ServiceConvergence.Status(node.hostname(), + 43, + "container", + 2)) + .collect(Collectors.toList()))); } /** The version given in the previous prepare call, or empty if no call has been made */ @@ -189,14 +199,24 @@ public class ConfigServerMock extends AbstractComponent implements ConfigServer public PrepareResponse prepareResponse() { Application application = applications.get(deployment.applicationId()); application.activate(); - for (Node node : nodeRepository.list(deployment.zoneId(), deployment.applicationId())) { + List<Node> nodes = nodeRepository.list(deployment.zoneId(), deployment.applicationId()); + for (Node node : nodes) { nodeRepository.putByHostname(deployment.zoneId(), new Node(node.hostname(), node.state(), node.type(), node.owner(), node.currentVersion(), application.version().get())); } - serviceStatus.remove(deployment); // Deployment is no longer converging after new deployment + serviceStatus.put(deployment, new ServiceConvergence(deployment.applicationId(), + deployment.zoneId(), + false, + 2, + nodes.stream() + .map(node -> new ServiceConvergence.Status(node.hostname(), + 43, + "container", + 1)) + .collect(Collectors.toList()))); PrepareResponse prepareResponse = new PrepareResponse(); prepareResponse.message = "foo"; @@ -223,6 +243,7 @@ public class ConfigServerMock extends AbstractComponent implements ConfigServer applications.remove(deployment.applicationId()); nodeRepository().removeByHostname(deployment.zoneId(), nodeRepository().list(deployment.zoneId(), deployment.applicationId())); + serviceStatus.remove(deployment); } // Returns a canned example response diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ContactInformationMaintainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ContactInformationMaintainerTest.java new file mode 100644 index 00000000000..e67fa6c2b46 --- /dev/null +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ContactInformationMaintainerTest.java @@ -0,0 +1,75 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.maintenance; + +import com.yahoo.config.provision.TenantName; +import com.yahoo.vespa.hosted.controller.ControllerTester; +import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; +import com.yahoo.vespa.hosted.controller.api.integration.organization.User; +import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; +import com.yahoo.vespa.hosted.controller.tenant.Contact; +import org.junit.Before; +import org.junit.Test; + +import java.net.URI; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * @author mpolden + */ +public class ContactInformationMaintainerTest { + + private ControllerTester tester; + private ContactInformationMaintainer maintainer; + + @Before + public void before() { + tester = new ControllerTester(); + maintainer = new ContactInformationMaintainer(tester.controller(), Duration.ofDays(1), new JobControl(tester.controller().curator())); + } + + @Test + public void updates_contact_information() { + long propertyId = 1; + TenantName name = tester.createTenant("tenant1", "domain1", propertyId); + Supplier<AthenzTenant> tenant = () -> tester.controller().tenants().requireAthenzTenant(name); + assertFalse("No contact information initially", tenant.get().contact().isPresent()); + + Contact contact = testContact(); + registerContact(propertyId, contact); + maintainer.run(); + + assertTrue("Contact information added", tenant.get().contact().isPresent()); + assertEquals(contact, tenant.get().contact().get()); + } + + private void registerContact(long propertyId, Contact contact) { + PropertyId p = new PropertyId(String.valueOf(propertyId)); + tester.organization().addProperty(p) + .setContactsUrl(p, contact.url()) + .setIssueUrl(p, contact.issueTrackerUrl()) + .setPropertyUrl(p, contact.propertyUrl()) + .setContactsFor(p, contact.persons().stream().map(persons -> persons.stream() + .map(User::from) + .collect(Collectors.toList())) + .collect(Collectors.toList())); + } + + private static Contact testContact() { + URI contactUrl = URI.create("http://contact1.test"); + URI issueTrackerUrl = URI.create("http://issue-tracker1.test"); + URI propertyUrl = URI.create("http://property1.test"); + List<List<String>> persons = Arrays.asList(Collections.singletonList("alice"), + Collections.singletonList("bob")); + return new Contact(contactUrl, propertyUrl, issueTrackerUrl, persons); + } + +} diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/JobRunnerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/JobRunnerTest.java index a77f0789314..0b2863dab1d 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/JobRunnerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/JobRunnerTest.java @@ -75,7 +75,6 @@ public class JobRunnerTest { public void multiThreadedExecutionFinishes() throws InterruptedException { DeploymentTester tester = new DeploymentTester(); JobController jobs = tester.controller().jobController(); - // Fail the installation of the initial version of the real application in staging tests, and succeed everything else. StepRunner stepRunner = (step, id) -> id.type() == stagingTest && step.get() == startTests? Optional.of(error) : Optional.of(running); CountDownLatch latch = new CountDownLatch(19); // Number of steps that will run, below: all but endTests in staging and all 9 in system. JobRunner runner = new JobRunner(tester.controller(), Duration.ofDays(1), new JobControl(tester.controller().curator()), @@ -93,9 +92,10 @@ public class JobRunnerTest { jobs.start(id, stagingTest, versions); assertTrue(jobs.last(id, systemTest).get().steps().values().stream().allMatch(unfinished::equals)); - runner.maintain(); assertFalse(jobs.last(id, systemTest).get().hasEnded()); + assertTrue(jobs.last(id, stagingTest).get().steps().values().stream().allMatch(unfinished::equals)); assertFalse(jobs.last(id, stagingTest).get().hasEnded()); + runner.maintain(); latch.await(1, TimeUnit.SECONDS); assertEquals(0, latch.getCount()); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/RunSerializerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/RunSerializerTest.java index 82aee3b3550..de9fe3f3dcc 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/RunSerializerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/RunSerializerTest.java @@ -30,12 +30,13 @@ import static com.yahoo.vespa.hosted.controller.deployment.Step.deactivateTester import static com.yahoo.vespa.hosted.controller.deployment.Step.deployInitialReal; import static com.yahoo.vespa.hosted.controller.deployment.Step.deployReal; import static com.yahoo.vespa.hosted.controller.deployment.Step.deployTester; +import static com.yahoo.vespa.hosted.controller.deployment.Step.endTests; import static com.yahoo.vespa.hosted.controller.deployment.Step.installInitialReal; import static com.yahoo.vespa.hosted.controller.deployment.Step.installReal; import static com.yahoo.vespa.hosted.controller.deployment.Step.installTester; import static com.yahoo.vespa.hosted.controller.deployment.Step.report; import static com.yahoo.vespa.hosted.controller.deployment.Step.startTests; -import static com.yahoo.vespa.hosted.controller.deployment.Step.endTests; +import static java.time.temporal.ChronoUnit.MILLIS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -97,7 +98,7 @@ public class RunSerializerTest { .build(), run.steps()); - run = run.aborted().finished(Instant.now()); + run = run.aborted().finished(Instant.now().truncatedTo(MILLIS)); assertEquals(aborted, run.status()); assertTrue(run.hasEnded()); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializerTest.java index fd909482072..38b09024cdf 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializerTest.java @@ -5,9 +5,13 @@ import com.yahoo.vespa.athenz.api.AthenzDomain; import com.yahoo.vespa.hosted.controller.api.identifiers.Property; import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; +import com.yahoo.vespa.hosted.controller.tenant.Contact; import com.yahoo.vespa.hosted.controller.tenant.UserTenant; import org.junit.Test; +import java.net.URI; +import java.util.Arrays; +import java.util.Collections; import java.util.Optional; import static org.junit.Assert.assertEquals; @@ -47,6 +51,25 @@ public class TenantSerializerTest { } @Test + public void athenz_tenant_with_contact() { + AthenzTenant tenant = new AthenzTenant(TenantName.from("athenz-tenant"), + new AthenzDomain("domain1"), + new Property("property1"), + Optional.of(new PropertyId("1")), + Optional.of(new Contact( + URI.create("http://contact1.test"), + URI.create("http://property1.test"), + URI.create("http://issue-tracker-1.test"), + Arrays.asList( + Collections.singletonList("person1"), + Collections.singletonList("person2") + ) + ))); + AthenzTenant serialized = serializer.athenzTenantFrom(serializer.toSlime(tenant)); + assertEquals(tenant.contact(), serialized.contact()); + } + + @Test public void user_tenant() { UserTenant tenant = UserTenant.create("by-foo"); UserTenant serialized = serializer.userTenantFrom(serializer.toSlime(tenant)); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java index 017479ecc90..13092451d4b 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java @@ -46,6 +46,8 @@ import com.yahoo.vespa.hosted.controller.athenz.mock.AthenzDbMock; import com.yahoo.vespa.hosted.controller.deployment.ApplicationPackageBuilder; import com.yahoo.vespa.hosted.controller.deployment.BuildJob; import com.yahoo.vespa.hosted.controller.integration.ConfigServerMock; +import com.yahoo.vespa.hosted.controller.maintenance.ContactInformationMaintainer; +import com.yahoo.vespa.hosted.controller.maintenance.JobControl; import com.yahoo.vespa.hosted.controller.restapi.ContainerControllerTester; import com.yahoo.vespa.hosted.controller.restapi.ContainerTester; import com.yahoo.vespa.hosted.controller.restapi.ControllerContainerTest; @@ -53,6 +55,7 @@ import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; import org.apache.http.HttpEntity; import org.apache.http.entity.ContentType; import org.apache.http.entity.mime.MultipartEntityBuilder; +import org.junit.Before; import org.junit.Test; import java.io.ByteArrayOutputStream; @@ -61,6 +64,7 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.net.URI; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; @@ -108,10 +112,18 @@ public class ApplicationApiTest extends ControllerContainerTest { private static final ZoneId TEST_ZONE = ZoneId.from(Environment.test, RegionName.from("us-east-1")); private static final ZoneId STAGING_ZONE = ZoneId.from(Environment.staging, RegionName.from("us-east-3")); + + private ContainerControllerTester controllerTester; + private ContainerTester tester; + + @Before + public void before() { + controllerTester = new ContainerControllerTester(container, responseFiles); + tester = controllerTester.containerTester(); + } + @Test - public void testApplicationApi() throws Exception { - ContainerControllerTester controllerTester = new ContainerControllerTester(container, responseFiles); - ContainerTester tester = controllerTester.containerTester(); + public void testApplicationApi() { tester.computeVersionStatus(); createAthenzDomainWithAdmin(ATHENZ_TENANT_DOMAIN, USER_ID); // (Necessary but not provided in this API) @@ -151,7 +163,8 @@ public class ApplicationApiTest extends ControllerContainerTest { // Add another Athens domain, so we can try to create more tenants createAthenzDomainWithAdmin(ATHENZ_TENANT_DOMAIN_2, USER_ID); // New domain to test tenant w/property ID // Add property info for that property id, as well, in the mock organization. - addPropertyData((MockOrganization) controllerTester.controller().organization(), "1234"); + registerContact(1234); + // POST (add) a tenant with property ID tester.assertResponse(request("/application/v4/tenant/tenant2", POST) .userIdentity(USER_ID) @@ -164,9 +177,10 @@ public class ApplicationApiTest extends ControllerContainerTest { .nToken(N_TOKEN) .data("{\"athensDomain\":\"domain2\", \"property\":\"property2\", \"propertyId\":\"1234\"}"), new File("tenant-without-applications-with-id.json")); - // GET a tenant with property ID + // GET a tenant with property ID and contact information + updateContactInformation(); tester.assertResponse(request("/application/v4/tenant/tenant2", GET).userIdentity(USER_ID), - new File("tenant-without-applications-with-id.json")); + new File("tenant-with-contact-info.json")); // POST (create) an application tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1", POST) @@ -465,8 +479,6 @@ public class ApplicationApiTest extends ControllerContainerTest { @Test public void testDeployDirectly() { // Setup - ContainerControllerTester controllerTester = new ContainerControllerTester(container, responseFiles); - ContainerTester tester = controllerTester.containerTester(); tester.computeVersionStatus(); createAthenzDomainWithAdmin(ATHENZ_TENANT_DOMAIN, USER_ID); @@ -500,8 +512,6 @@ public class ApplicationApiTest extends ControllerContainerTest { @Test public void testDeployDirectlyUsingOneCallForDeploy() { // Setup - ContainerControllerTester controllerTester = new ContainerControllerTester(container, responseFiles); - ContainerTester tester = controllerTester.containerTester(); tester.computeVersionStatus(); UserId userId = new UserId("new_user"); createAthenzDomainWithAdmin(ATHENZ_TENANT_DOMAIN, userId); @@ -523,10 +533,8 @@ public class ApplicationApiTest extends ControllerContainerTest { } @Test - public void testSortsDeploymentsAndJobs() throws Exception { + public void testSortsDeploymentsAndJobs() { // Setup - ContainerControllerTester controllerTester = new ContainerControllerTester(container, responseFiles); - ContainerTester tester = controllerTester.containerTester(); tester.computeVersionStatus(); createAthenzDomainWithAdmin(ATHENZ_TENANT_DOMAIN, USER_ID); @@ -602,7 +610,6 @@ public class ApplicationApiTest extends ControllerContainerTest { @Test public void testErrorResponses() throws Exception { - ContainerTester tester = new ContainerTester(container, responseFiles); tester.computeVersionStatus(); createAthenzDomainWithAdmin(ATHENZ_TENANT_DOMAIN, USER_ID); @@ -749,7 +756,7 @@ public class ApplicationApiTest extends ControllerContainerTest { // Create legancy tenant name containing underscores tester.controller().tenants().create(new AthenzTenant(TenantName.from("my_tenant"), ATHENZ_TENANT_DOMAIN, - new Property("property1"), Optional.empty()), + new Property("property1"), Optional.empty(), Optional.empty()), N_TOKEN); // POST (add) a Athenz tenant with dashes duplicates existing one with underscores tester.assertResponse(request("/application/v4/tenant/my-tenant", POST) @@ -761,8 +768,7 @@ public class ApplicationApiTest extends ControllerContainerTest { } @Test - public void testAuthorization() throws Exception { - ContainerTester tester = new ContainerTester(container, responseFiles); + public void testAuthorization() { UserId authorizedUser = USER_ID; UserId unauthorizedUser = new UserId("othertenant"); @@ -855,9 +861,7 @@ public class ApplicationApiTest extends ControllerContainerTest { } @Test - public void deployment_fails_on_illegal_domain_in_deployment_spec() throws IOException { - ContainerControllerTester controllerTester = new ContainerControllerTester(container, responseFiles); - ContainerTester tester = controllerTester.containerTester(); + public void deployment_fails_on_illegal_domain_in_deployment_spec() { ApplicationPackage applicationPackage = new ApplicationPackageBuilder() .upgradePolicy("default") .athenzIdentity(com.yahoo.config.provision.AthenzDomain.from("invalid.domain"), com.yahoo.config.provision.AthenzService.from("service")) @@ -881,8 +885,6 @@ public class ApplicationApiTest extends ControllerContainerTest { @Test public void deployment_succeeds_when_correct_domain_is_used() { - ContainerControllerTester controllerTester = new ContainerControllerTester(container, responseFiles); - ContainerTester tester = controllerTester.containerTester(); ApplicationPackage applicationPackage = new ApplicationPackageBuilder() .upgradePolicy("default") .athenzIdentity(com.yahoo.config.provision.AthenzDomain.from("domain1"), com.yahoo.config.provision.AthenzService.from("service")) @@ -912,11 +914,10 @@ public class ApplicationApiTest extends ControllerContainerTest { @Test public void testJobStatusReporting() { - ContainerControllerTester tester = new ContainerControllerTester(container, responseFiles); addUserToHostedOperatorRole(HostedAthenzIdentities.from(HOSTED_VESPA_OPERATOR)); - tester.containerTester().computeVersionStatus(); + tester.computeVersionStatus(); long projectId = 1; - Application app = tester.createApplication(); + Application app = controllerTester.createApplication(); ApplicationPackage applicationPackage = new ApplicationPackageBuilder() .environment(Environment.prod) .region("corp-us-east-1") @@ -924,11 +925,11 @@ public class ApplicationApiTest extends ControllerContainerTest { Version vespaVersion = new Version("6.1"); // system version from mock config server client - BuildJob job = new BuildJob(report -> notifyCompletion(report, tester), tester.artifactRepository()) + BuildJob job = new BuildJob(report -> notifyCompletion(report, controllerTester), controllerTester.artifactRepository()) .application(app) .projectId(projectId); job.type(JobType.component).uploadArtifact(applicationPackage).submit(); - tester.deploy(app, applicationPackage, TEST_ZONE); + controllerTester.deploy(app, applicationPackage, TEST_ZONE); job.type(JobType.systemTest).submit(); // Notifying about unknown job fails @@ -936,7 +937,7 @@ public class ApplicationApiTest extends ControllerContainerTest { .data(asJson(job.type(JobType.productionUsEast3).report())) .userIdentity(HOSTED_VESPA_OPERATOR) .get(); - tester.containerTester().assertResponse(request, new File("jobreport-unexpected-completion.json"), 400); + tester.assertResponse(request, new File("jobreport-unexpected-completion.json"), 400); // ... and assert it was recorded JobStatus recordedStatus = @@ -960,25 +961,24 @@ public class ApplicationApiTest extends ControllerContainerTest { @Test public void testJobStatusReportingOutOfCapacity() { - ContainerControllerTester tester = new ContainerControllerTester(container, responseFiles); - tester.containerTester().computeVersionStatus(); + controllerTester.containerTester().computeVersionStatus(); long projectId = 1; - Application app = tester.createApplication(); + Application app = controllerTester.createApplication(); ApplicationPackage applicationPackage = new ApplicationPackageBuilder() .environment(Environment.prod) .region("corp-us-east-1") .build(); // Report job failing with out of capacity - BuildJob job = new BuildJob(report -> notifyCompletion(report, tester), tester.artifactRepository()) + BuildJob job = new BuildJob(report -> notifyCompletion(report, controllerTester), controllerTester.artifactRepository()) .application(app) .projectId(projectId); job.type(JobType.component).uploadArtifact(applicationPackage).submit(); - tester.deploy(app, applicationPackage, TEST_ZONE); + controllerTester.deploy(app, applicationPackage, TEST_ZONE); job.type(JobType.systemTest).submit(); - tester.deploy(app, applicationPackage, STAGING_ZONE); + controllerTester.deploy(app, applicationPackage, STAGING_ZONE); job.type(JobType.stagingTest).error(DeploymentJobs.JobError.outOfCapacity).submit(); // Appropriate error is recorded @@ -1134,7 +1134,7 @@ public class ApplicationApiTest extends ControllerContainerTest { private void startAndTestChange(ContainerControllerTester controllerTester, ApplicationId application, long projectId, ApplicationPackage applicationPackage, - HttpEntity deployData, long buildNumber) throws IOException { + HttpEntity deployData, long buildNumber) { ContainerTester tester = controllerTester.containerTester(); // Trigger application change @@ -1208,11 +1208,22 @@ public class ApplicationApiTest extends ControllerContainerTest { } } - private void addPropertyData(MockOrganization organization, String propertyIdValue) { - PropertyId propertyId = new PropertyId(propertyIdValue); - organization.addProperty(propertyId); - organization.setContactsFor(propertyId, Arrays.asList(Collections.singletonList(User.from("alice")), - Collections.singletonList(User.from("bob")))); + private MockOrganization organization() { + return (MockOrganization) tester.container().components().getComponent(MockOrganization.class.getName()); + } + + private void updateContactInformation() { + new ContactInformationMaintainer(tester.controller(), Duration.ofDays(1), new JobControl(tester.controller().curator())).run(); + } + + private void registerContact(long propertyId) { + PropertyId p = new PropertyId(String.valueOf(propertyId)); + organization().addProperty(p) + .setIssueUrl(p, URI.create("www.issues.tld/" + p.id())) + .setContactsUrl(p, URI.create("www.contacts.tld/" + p.id())) + .setPropertyUrl(p, URI.create("www.properties.tld/" + p.id())) + .setContactsFor(p, Arrays.asList(Collections.singletonList(User.from("alice")), + Collections.singletonList(User.from("bob")))); } } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelperTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelperTest.java index 01f9ea9dfa0..f6b33940929 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelperTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelperTest.java @@ -2,11 +2,7 @@ package com.yahoo.vespa.hosted.controller.restapi.application; import com.yahoo.component.Version; import com.yahoo.container.jdisc.HttpResponse; -import com.yahoo.vespa.hosted.controller.api.application.v4.model.configserverbindings.ConfigChangeActions; -import com.yahoo.vespa.hosted.controller.api.application.v4.model.configserverbindings.RefeedAction; -import com.yahoo.vespa.hosted.controller.api.identifiers.DeploymentId; import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServerException; -import com.yahoo.vespa.hosted.controller.api.integration.deployment.TesterCloud; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.ApplicationVersion; import com.yahoo.vespa.hosted.controller.deployment.InternalDeploymentTester; @@ -32,14 +28,10 @@ import static com.yahoo.vespa.hosted.controller.api.integration.deployment.JobTy import static com.yahoo.vespa.hosted.controller.api.integration.deployment.TesterCloud.Status.FAILURE; import static com.yahoo.vespa.hosted.controller.deployment.InternalDeploymentTester.appId; import static com.yahoo.vespa.hosted.controller.deployment.JobController.testerOf; -import static com.yahoo.vespa.hosted.controller.deployment.RunStatus.aborted; import static com.yahoo.vespa.hosted.controller.deployment.RunStatus.deploymentFailed; -import static com.yahoo.vespa.hosted.controller.deployment.RunStatus.error; import static com.yahoo.vespa.hosted.controller.deployment.RunStatus.installationFailed; import static com.yahoo.vespa.hosted.controller.deployment.RunStatus.running; import static com.yahoo.vespa.hosted.controller.deployment.RunStatus.testFailure; -import static java.util.Collections.emptyList; -import static java.util.Collections.singletonList; import static org.junit.Assert.assertEquals; /** diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json index ee54e2741ba..07a3dbb7f95 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json @@ -1,6 +1,7 @@ { "application": "application1", "instance": "default", + "deployments": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/instance/default/job/", "deploymentJobs": [ { "type": "component", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json index c93ff6a0dd2..0d7607f1df6 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json @@ -1,6 +1,7 @@ { "application": "application1", "instance": "default", + "deployments": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/instance/default/job/", "deploying": { "revision": { "hash": "(ignore)", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json index a1bd96e46d2..4e4a870662a 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json @@ -1,6 +1,7 @@ { "application": "application1", "instance": "default", + "deployments": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/instance/default/job/", "deploying": { "revision": { "hash": "(ignore)", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json index fa51d645cfc..837c46aaec1 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json @@ -1,6 +1,7 @@ { "application": "application2", "instance": "default", + "deployments": "http://localhost:8080/application/v4/tenant/tenant2/application/application2/instance/default/job/", "deploying": { "version": "(ignore)" }, diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/overview.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/overview.json index b7a7dcbf796..a7c3135fbf3 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/overview.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/overview.json @@ -1,17 +1,19 @@ { "lastVersions": { "platform": { - "version": "7.1", + "platform": "7.1", "at": 0, "pending": "Waiting for current deployment to complete" }, "application": { - "version": { - "id": "1.0.3-commit1", + "application": { + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "at": 2000, "deploying": "0 of 3 complete" @@ -19,11 +21,13 @@ }, "deploying": { "application": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } } }, "deployments": [ @@ -32,11 +36,13 @@ "at": 2000, "platform": "6.1", "application": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "verified": false, "status": "verifying" @@ -47,11 +53,13 @@ "at": 1000, "platform": "6.1", "application": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "verified": false, "status": "pending" @@ -60,11 +68,13 @@ "at": 0, "platform": "6.1", "application": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "verified": true, "status": "pending" @@ -81,19 +91,23 @@ "end": 2000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -109,19 +123,23 @@ "end": 1000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -137,11 +155,13 @@ "end": 0, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -159,19 +179,23 @@ "status": "pending", "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "cooldown": "failed" @@ -185,19 +209,23 @@ "end": 2000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": {}, "log": "https://some.url:43/root/staging-test/run/4" @@ -209,19 +237,23 @@ "end": 2000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -237,19 +269,23 @@ "end": 1000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -265,11 +301,13 @@ "end": 0, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -289,19 +327,23 @@ "start": 2000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -316,19 +358,23 @@ "end": 1000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -344,11 +390,13 @@ "end": 0, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -366,19 +414,23 @@ "status": "pending", "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "production-us-central-1": "running" @@ -391,19 +443,23 @@ "end": 1000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -419,11 +475,13 @@ "end": 0, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", @@ -441,19 +499,23 @@ "status": "pending", "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "staging-test": "failed", @@ -467,19 +529,23 @@ "end": 1000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "failed" @@ -493,11 +559,13 @@ "end": 0, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "tasks": { "deploy": "succeeded", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/staging-runs.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/staging-runs.json index 448411b3912..8c5e5253482 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/staging-runs.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/staging-runs.json @@ -6,11 +6,13 @@ "end": 0, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "steps": { "deployInitialReal": "succeeded", @@ -39,19 +41,23 @@ "end": 1000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "steps": { "deployInitialReal": "succeeded", @@ -80,19 +86,23 @@ "end": 2000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.2-commit1", + "hash": "1.0.2-commit1", "build": 2, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "steps": { "deployInitialReal": "succeeded", @@ -121,19 +131,23 @@ "end": 2000, "wantedPlatform": "6.1", "wantedApplication": { - "id": "1.0.3-commit1", + "hash": "1.0.3-commit1", "build": 3, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "currentPlatform": "6.1", "currentApplication": { - "id": "1.0.1-commit1", + "hash": "1.0.1-commit1", "build": 1, - "repository": "repository1", - "branch": "master", - "commit": "commit1" + "source": { + "gitRepository": "repository1", + "gitBranch": "master", + "gitCommit": "commit1" + } }, "steps": { "deployInitialReal": "succeeded", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-with-contact-info.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-with-contact-info.json new file mode 100644 index 00000000000..0ba0a01c5d0 --- /dev/null +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-with-contact-info.json @@ -0,0 +1,19 @@ +{ + "tenant": "tenant2", + "type": "ATHENS", + "athensDomain": "domain2", + "property": "property2", + "propertyId": "1234", + "applications": [], + "propertyUrl": "www.properties.tld/1234", + "contactsUrl": "www.contacts.tld/1234", + "issueCreationUrl": "www.issues.tld/1234", + "contacts": [ + [ + "alice" + ], + [ + "bob" + ] + ] +} diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json index 2b847010482..6a71e524ae4 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json @@ -10,6 +10,9 @@ "name": "ClusterUtilizationMaintainer" }, { + "name": "ContactInformationMaintainer" + }, + { "name": "DefaultOsUpgrader" }, { diff --git a/controller-server/src/test/resources/test_runner_services.xml-cd b/controller-server/src/test/resources/test_runner_services.xml-cd index 9c6cfe6fe2d..e0fca9716eb 100644 --- a/controller-server/src/test/resources/test_runner_services.xml-cd +++ b/controller-server/src/test/resources/test_runner_services.xml-cd @@ -37,6 +37,6 @@ </filtering> </http> - <nodes count="1" flavor="d-2-8-50" /> + <nodes count="1" flavor="d-1-4-50" /> </container> </services> diff --git a/docker-api/pom.xml b/docker-api/pom.xml index 64410c32f06..74e463ef157 100644 --- a/docker-api/pom.xml +++ b/docker-api/pom.xml @@ -18,16 +18,20 @@ <name>${project.artifactId}</name> <dependencies> + <!-- Provided --> <dependency> <groupId>com.yahoo.vespa</groupId> <artifactId>container-dev</artifactId> <version>${project.version}</version> <scope>provided</scope> </dependency> + + <!-- Compile --> <dependency> <groupId>com.github.docker-java</groupId> <artifactId>docker-java</artifactId> <version>3.0.13</version> + <scope>compile</scope> <exclusions> <exclusion> <groupId>org.slf4j</groupId> @@ -92,6 +96,7 @@ <dependency> <groupId>net.jpountz.lz4</groupId> <artifactId>lz4</artifactId> + <scope>compile</scope> </dependency> <dependency> <groupId>org.apache.httpcomponents</groupId> @@ -100,6 +105,7 @@ docker-java so the dependency is declared closer to the root of maven and more likely be the version that is finally being used. --> <version>4.4.1</version> + <scope>compile</scope> </dependency> <dependency> <groupId>org.apache.httpcomponents</groupId> @@ -108,7 +114,10 @@ docker-java so the dependency is declared closer to the root of maven and more likely be the version that is finally being used. --> <version>4.5</version> + <scope>compile</scope> </dependency> + + <!-- Test --> <dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> @@ -121,7 +130,6 @@ </dependency> </dependencies> - <build> <plugins> <plugin> diff --git a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImpl.java b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImpl.java index 260e2da7c59..d95f7b7b8e1 100644 --- a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImpl.java +++ b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImpl.java @@ -107,7 +107,14 @@ class CreateContainerCommandImpl implements Docker.CreateContainerCommand { @Override public Docker.CreateContainerCommand withVolume(String path, String volumePath) { assert path.indexOf(':') == -1; - volumeBindSpecs.add(path + ":" + volumePath); + volumeBindSpecs.add(path + ":" + volumePath + ":Z"); + return this; + } + + @Override + public Docker.CreateContainerCommand withSharedVolume(String path, String volumePath) { + assert path.indexOf(':') == -1; + volumeBindSpecs.add(path + ":" + volumePath + ":z"); return this; } diff --git a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/Docker.java b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/Docker.java index 91d5125eba3..5e8a0feb099 100644 --- a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/Docker.java +++ b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/Docker.java @@ -19,7 +19,30 @@ public interface Docker { interface CreateContainerCommand { CreateContainerCommand withLabel(String name, String value); CreateContainerCommand withEnvironment(String name, String value); + + /** + * Mounts a directory on host inside the docker container. + * + * <p>Bind mount content will be <b>private</b> to this container (and host) only. + * + * <p>When using this method and selinux is enabled (/usr/sbin/sestatus), starting + * multiple containers which mount host's /foo directory into the container, will make + * /foo's content visible/readable/writable only inside the container which was last + * started and on the host. All the other containers will get "Permission denied". + * + * <p>Use {@link #withSharedVolume(String, String)} to mount a given host directory + * into multiple containers. + */ CreateContainerCommand withVolume(String path, String volumePath); + + /** + * Mounts a directory on host inside the docker container. + * + * <p>The bind mount content will be <b>shared</b> among multiple containers. + * + * @see #withVolume(String, String) + */ + CreateContainerCommand withSharedVolume(String path, String volumePath); CreateContainerCommand withNetworkMode(String mode); CreateContainerCommand withIpAddress(InetAddress address); CreateContainerCommand withUlimit(String name, int softLimit, int hardLimit); diff --git a/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImplTest.java b/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImplTest.java index 0d8701ac43c..5ce8c6b093c 100644 --- a/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImplTest.java +++ b/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/CreateContainerCommandImplTest.java @@ -46,7 +46,7 @@ public class CreateContainerCommandImplTest { "--ulimit nproc=10:20 " + "--env env1=val1 " + "--env env2=val2 " + - "--volume vol1:/host/vol1 " + + "--volume vol1:/host/vol1:Z " + "--cap-add SYS_ADMIN " + "--cap-add SYS_PTRACE " + "--cap-drop NET_ADMIN " + diff --git a/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java b/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java index 80f1f003412..0f3f3938701 100644 --- a/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java +++ b/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java @@ -34,7 +34,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.logging.Logger; /** - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> + * @author Einar M R Rosenvinge */ public class MbusRequestContext implements RequestContext, ResponseHandler { diff --git a/document/src/main/java/com/yahoo/document/DocumentUpdate.java b/document/src/main/java/com/yahoo/document/DocumentUpdate.java index 70c5410534e..ad93942c1c0 100644 --- a/document/src/main/java/com/yahoo/document/DocumentUpdate.java +++ b/document/src/main/java/com/yahoo/document/DocumentUpdate.java @@ -7,6 +7,7 @@ import com.yahoo.document.serialization.DocumentSerializerFactory; import com.yahoo.document.serialization.DocumentUpdateReader; import com.yahoo.document.serialization.DocumentUpdateWriter; import com.yahoo.document.update.AssignValueUpdate; +import com.yahoo.document.update.ClearValueUpdate; import com.yahoo.document.update.FieldUpdate; import com.yahoo.document.update.ValueUpdate; import com.yahoo.io.GrowableByteBuffer; @@ -137,9 +138,20 @@ public class DocumentUpdate extends DocumentOperation implements Iterable<FieldP ValueUpdate last = update.getValueUpdate(update.size() - 1); if (last instanceof AssignValueUpdate) { FieldValue currentValue = doc.getFieldValue(update.getField()); - if ((currentValue != null) && (currentValue.compareTo(last.getValue()) == 0)) { + if ((currentValue != null) && currentValue.equals(last.getValue())) { iter.remove(); } + } else if (last instanceof ClearValueUpdate) { + FieldValue currentValue = doc.getFieldValue(update.getField()); + if (currentValue == null) { + iter.remove(); + } else { + FieldValue copy = currentValue.clone(); + copy.clear(); + if (currentValue.equals(copy)) { + iter.remove(); + } + } } } } diff --git a/document/src/main/java/com/yahoo/document/datatypes/MapFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/MapFieldValue.java index 6d6c18755c1..bf40546d637 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/MapFieldValue.java +++ b/document/src/main/java/com/yahoo/document/datatypes/MapFieldValue.java @@ -10,7 +10,16 @@ import com.yahoo.document.serialization.FieldReader; import com.yahoo.document.serialization.FieldWriter; import com.yahoo.document.serialization.XmlSerializationHelper; import com.yahoo.document.serialization.XmlStream; -import java.util.*; + +import java.util.Arrays; +import java.util.Map; +import java.util.HashMap; +import java.util.Set; +import java.util.List; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; + /** * Vespa map. Backed by and and parametrized by FieldValue @@ -72,10 +81,7 @@ public class MapFieldValue<K extends FieldValue, V extends FieldValue> extends C */ public boolean equals(Object o) { if (!(o instanceof MapFieldValue)) return false; - MapFieldValue otherSet = (MapFieldValue) o; - Map<K, V> map1 = values; - Map<K, V> map2 = otherSet.values; - return (super.equals(o) && map1.equals(map2)); + return super.equals(o) && values.equals(((MapFieldValue) o).values); } @Override @@ -276,14 +282,24 @@ public class MapFieldValue<K extends FieldValue, V extends FieldValue> extends C return comp; } //types are equal, this must be of this type - MapFieldValue otherValue = (MapFieldValue) fieldValue; - comp = CollectionComparator.compare(values.keySet(), otherValue.values.keySet()); - - if (comp != 0) { - return comp; + MapFieldValue<K,V> rhs = (MapFieldValue<K,V>) fieldValue; + if (size() < rhs.size()) { + return -1; + } else if (size() > rhs.size()) { + return 1; + } + Map.Entry<K,V> [] entries = entrySet().toArray(new Map.Entry[size()]); + Map.Entry<K,V> [] rhsEntries = rhs.entrySet().toArray(new Map.Entry[rhs.size()]); + Arrays.sort(entries, (Map.Entry<K,V> a, Map.Entry<K,V> b) -> { return a.getKey().compareTo(b.getKey()); }); + Arrays.sort(rhsEntries, (Map.Entry<K,V> a, Map.Entry<K,V> b) -> { return a.getKey().compareTo(b.getKey()); }); + for (int i = 0; i < entries.length; i++) { + comp = entries[i].getKey().compareTo(rhsEntries[i].getKey()); + if (comp != 0) return comp; + comp = entries[i].getValue().compareTo(rhsEntries[i].getValue()); + if (comp != 0) return comp; } - return CollectionComparator.compare(values.values(), otherValue.values.values()); + return 0; } /** diff --git a/document/src/main/java/com/yahoo/document/datatypes/WeightedSet.java b/document/src/main/java/com/yahoo/document/datatypes/WeightedSet.java index 0e4c56406f0..63dc1cab063 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/WeightedSet.java +++ b/document/src/main/java/com/yahoo/document/datatypes/WeightedSet.java @@ -241,8 +241,7 @@ public final class WeightedSet<K extends FieldValue> extends CollectionFieldValu */ public boolean equals(Object o) { if (!(o instanceof WeightedSet)) return false; - WeightedSet otherSet = (WeightedSet) o; - return (super.equals(o) && map.equals(otherSet.map)); + return (super.equals(o) && map.equals(((WeightedSet<K>)o).map)); } /** @@ -293,15 +292,7 @@ public final class WeightedSet<K extends FieldValue> extends CollectionFieldValu return comp; } - //types are equal, this must be of this type - WeightedSet otherValue = (WeightedSet) fieldValue; - comp = CollectionComparator.compare(map.keySet(), otherValue.map.keySet()); - - if (comp != 0) { - return comp; - } - - return CollectionComparator.compare(map.values(), otherValue.map.values()); + return map.compareTo(((WeightedSet<K>)fieldValue).map); } diff --git a/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java b/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java index 4f3d7d3b820..15319985591 100644 --- a/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java +++ b/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java @@ -709,6 +709,48 @@ public class DocumentUpdateTestCase { assertEquals(expected, doc.getFieldValue(field).getWrappedValue()); } + @Test + public void testThatClearCanBePrunedIfNoneExisting() { + Field field = docType.getField("strfoo"); + Document doc = createDocument(); + StringFieldValue expected = new StringFieldValue("some value"); + expected.clear(); + doc.setFieldValue(field, expected); + DocumentUpdate update = new DocumentUpdate(docType, new DocumentId(documentId)); + update.addFieldUpdate(FieldUpdate.createClearField(field)); + update.prune(doc); + assertEquals(0, update.size()); + update.applyTo(doc); + assertEquals(expected, doc.getFieldValue(field)); + } + + @Test + public void testThatClearCanBePrunedIfEmpty() { + Field field = docType.getField("strfoo"); + String expected = ""; + Document doc = createDocument(); + DocumentUpdate update = new DocumentUpdate(docType, new DocumentId(documentId)); + update.addFieldUpdate(FieldUpdate.createClearField(field)); + update.prune(doc); + assertEquals(0, update.size()); + update.applyTo(doc); + assertNull(doc.getFieldValue(field)); + } + + @Test + public void testThatClearCanBePrunedIfNoneExistingAndLast() { + Field field = docType.getField("strfoo"); + String expected = ""; + Document doc = createDocument(); + DocumentUpdate update = new DocumentUpdate(docType, new DocumentId(documentId)); + update.addFieldUpdate(FieldUpdate.createAssign(field, new StringFieldValue("some value"))); + update.addFieldUpdate(FieldUpdate.createClearField(field)); + update.prune(doc); + assertEquals(0, update.size()); + update.applyTo(doc); + assertNull(doc.getFieldValue(field)); + } + private static TensorFieldValue createTensorFieldValue(String tensor) { return new TensorFieldValue(Tensor.from(tensor)); } diff --git a/document/src/test/java/com/yahoo/document/datatypes/WeightedSetTestCase.java b/document/src/test/java/com/yahoo/document/datatypes/WeightedSetTestCase.java index 3436c73feae..107da479f72 100644 --- a/document/src/test/java/com/yahoo/document/datatypes/WeightedSetTestCase.java +++ b/document/src/test/java/com/yahoo/document/datatypes/WeightedSetTestCase.java @@ -2,13 +2,13 @@ package com.yahoo.document.datatypes; import com.yahoo.document.DataType; -import com.yahoo.document.MapDataType; import org.junit.Test; import java.util.LinkedHashMap; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; @@ -21,6 +21,60 @@ import static org.junit.Assert.fail; public class WeightedSetTestCase { @Test + public void testEquals() { + WeightedSet<StringFieldValue> a = new WeightedSet<>(DataType.TAG); + a.put(new StringFieldValue("this is a test"), 5); + a.put(new StringFieldValue("this is a second test"), 7); + + WeightedSet<StringFieldValue> b = new WeightedSet<>(DataType.TAG); + b.put(new StringFieldValue("this is a second test"), 7); + b.put(new StringFieldValue("this is a test"), 5); + + assertEquals(a, b); + assertEquals(0, a.compareTo(b)); + assertEquals(0, b.compareTo(a)); + + } + + @Test + public void testCompareTo() { + WeightedSet<StringFieldValue> a = new WeightedSet<>(DataType.TAG); + a.put(new StringFieldValue("this is a test"), 5); + a.put(new StringFieldValue("this is a second test"), 7); + + WeightedSet<StringFieldValue> b = new WeightedSet<>(DataType.TAG); + b.put(new StringFieldValue("this is a test"), 5); + + assertNotEquals(a, b); + assertEquals(1, a.compareTo(b)); + assertEquals(-1, b.compareTo(a)); + + b.clear(); + b.put(new StringFieldValue("this is a test"), 5); + b.put(new StringFieldValue("this is a third test"), 7); + + assertNotEquals(a, b); + assertEquals(-1, a.compareTo(b)); + assertEquals(1, b.compareTo(a)); + + b.clear(); + b.put(new StringFieldValue("this is a test"), 5); + b.put(new StringFieldValue("this is a second test"), 7); + + assertEquals(a, b); + assertEquals(0, a.compareTo(b)); + assertEquals(0, b.compareTo(a)); + + b.clear(); + b.put(new StringFieldValue("this is a test"), 5); + b.put(new StringFieldValue("this is a second test"), 6); + + assertNotEquals(a, b); + assertEquals(1, a.compareTo(b)); + assertEquals(-1, b.compareTo(a)); + } + + @Test public void testSet() { WeightedSet<StringFieldValue> wset = new WeightedSet<>(DataType.TAG); diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusAsyncSession.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusAsyncSession.java index 42753338d06..eb71b3cfe47 100644 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusAsyncSession.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusAsyncSession.java @@ -26,7 +26,7 @@ import java.util.logging.Logger; * The sessions are multithread safe. * * @author bratseth - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar Rosenvinge</a> + * @author Einar Rosenvinge */ public class MessageBusAsyncSession implements MessageBusSession, AsyncSession { diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java index 7423792693b..dbf68106e07 100755 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java @@ -20,7 +20,7 @@ import java.util.List; public class ANDPolicy implements DocumentProtocolRoutingPolicy { // A list of hops that are to always be selected when select() is invoked. - private final List<Hop> hops = new ArrayList<Hop>(); + private final List<Hop> hops = new ArrayList<>(); /** * Constructs a new AND policy that requires all recipients to be ok for it to merge their replies to an ok reply. diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java index 82679e17990..a5b3accac68 100644 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java @@ -51,7 +51,7 @@ public class MessageTypePolicy implements DocumentProtocolRoutingPolicy, ConfigS @Override public void configure(MessagetyperouteselectorpolicyConfig cfg) { - Map<Integer, Route> h = new HashMap<Integer, Route>(); + Map<Integer, Route> h = new HashMap<>(); for (MessagetyperouteselectorpolicyConfig.Route selector : cfg.route()) { h.put(selector.messagetype(), Route.parse(selector.name())); } diff --git a/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerFactory.java b/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerFactory.java index d8ea45e716d..e8a3038639a 100644 --- a/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerFactory.java +++ b/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerFactory.java @@ -4,10 +4,13 @@ package com.yahoo.filedistribution.fileacquirer; /** * Hides the real file acquirer type from 3rd party developers. * Not intended to be used by 3rd parties. + * * @author Tony Vaagenes */ public class FileAcquirerFactory { + public static FileAcquirer create(String configId) { return new FileAcquirerImpl(configId); } + } diff --git a/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerImpl.java b/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerImpl.java index fca4b206fc9..ab0f7521e7e 100644 --- a/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerImpl.java +++ b/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/FileAcquirerImpl.java @@ -23,12 +23,15 @@ import java.io.File; * @author Tony Vaagenes */ class FileAcquirerImpl implements FileAcquirer { + static final class FileDistributionErrorCode { + public static final int baseErrorCode = 0x10000; public static final int baseFileProviderErrorCode = baseErrorCode + 0x1000; public static final int fileReferenceDoesNotExists = baseFileProviderErrorCode; public static final int fileReferenceRemoved = fileReferenceDoesNotExists + 1; + } private static final Logger log = Logger.getLogger(FileAcquirerImpl.class.getName()); @@ -131,13 +134,10 @@ class FileAcquirerImpl implements FileAcquirer { * given file reference. File references are produced by the * config system. * - * @throws TimeoutException if the file or directory could not be - * retrieved in time. - * @throws FileReferenceDoesNotExistException if the file is no - * longer available (due to reloading of config). + * @throws TimeoutException if the file or directory could not be retrieved in time. + * @throws FileReferenceDoesNotExistException if the file is no longer available (due to reloading of config). */ - public File waitFor(FileReference fileReference, long timeout, TimeUnit timeUnit) - throws InterruptedException { + public File waitFor(FileReference fileReference, long timeout, TimeUnit timeUnit) throws InterruptedException { Timer timer = new Timer(timeout, timeUnit); do { Target target = connection.getTarget(timer); diff --git a/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/MockFileAcquirer.java b/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/MockFileAcquirer.java index 25732d2dcc8..1a8a05d0a53 100644 --- a/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/MockFileAcquirer.java +++ b/fileacquirer/src/main/java/com/yahoo/filedistribution/fileacquirer/MockFileAcquirer.java @@ -14,8 +14,9 @@ import java.util.concurrent.TimeUnit; * @author Tony Vaagenes */ public abstract class MockFileAcquirer implements FileAcquirer { + /** Creates a FileAcquirer that always returns the given file. **/ - public static FileAcquirer returnFile(final File file) { + public static FileAcquirer returnFile(File file) { return new MockFileAcquirer() { @Override public File waitFor(FileReference fileReference, @@ -26,7 +27,7 @@ public abstract class MockFileAcquirer implements FileAcquirer { } /** Creates a FileAcquirer that maps from fileReference.value to a file. **/ - public static FileAcquirer returnFiles(final Map<String, File> files) { + public static FileAcquirer returnFiles(Map<String, File> files) { return new MockFileAcquirer() { @Override public File waitFor(FileReference fileReference, @@ -60,4 +61,5 @@ public abstract class MockFileAcquirer implements FileAcquirer { @Override public void shutdown() {} + } diff --git a/fnet/src/examples/frt/rpc/rpc_callback_client.cpp b/fnet/src/examples/frt/rpc/rpc_callback_client.cpp index 801de59b515..7c6434e870a 100644 --- a/fnet/src/examples/frt/rpc/rpc_callback_client.cpp +++ b/fnet/src/examples/frt/rpc/rpc_callback_client.cpp @@ -26,7 +26,7 @@ RPC::Init(FRT_Supervisor *s) { FRT_ReflectionBuilder rb(s); //------------------------------------------------------------------- - rb.DefineMethod("prod", "", "", true, + rb.DefineMethod("prod", "", "", FRT_METHOD(RPC::Prod), this); //------------------------------------------------------------------- } @@ -45,6 +45,7 @@ MyApp::Main() printf("usage : rpc_server <connectspec>\n"); return 1; } + bool ok = true; RPC rpc; FRT_Supervisor orb; rpc.Init(&orb); @@ -63,6 +64,7 @@ MyApp::Main() printf("[error(%d): %s]\n", req->GetErrorCode(), req->GetErrorMessage()); + ok = false; } printf("invokeCnt: %d\n", rpc.invokeCnt); @@ -76,6 +78,7 @@ MyApp::Main() printf("[error(%d): %s]\n", req->GetErrorCode(), req->GetErrorMessage()); + ok = false; } printf("invokeCnt: %d\n", rpc.invokeCnt); @@ -89,14 +92,18 @@ MyApp::Main() printf("[error(%d): %s]\n", req->GetErrorCode(), req->GetErrorMessage()); + ok = false; } printf("invokeCnt: %d\n", rpc.invokeCnt); + if (rpc.invokeCnt != 3) { + ok = false; + } req->SubRef(); target->SubRef(); orb.ShutDown(true); - return 0; + return ok ? 0 : 1; } diff --git a/fnet/src/examples/frt/rpc/rpc_callback_server.cpp b/fnet/src/examples/frt/rpc/rpc_callback_server.cpp index 33207282bcc..ac7b34ebda0 100644 --- a/fnet/src/examples/frt/rpc/rpc_callback_server.cpp +++ b/fnet/src/examples/frt/rpc/rpc_callback_server.cpp @@ -2,6 +2,7 @@ #include <vespa/fnet/frt/frt.h> #include <vespa/fastos/app.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP("rpc_callback_server"); @@ -12,9 +13,7 @@ struct RPC : public FRT_Invokable void Init(FRT_Supervisor *s); }; -void -RPC::CallBack(FRT_RPCRequest *req) -{ +void do_callback(FRT_RPCRequest *req) { FNET_Connection *conn = req->GetConnection(); FRT_RPCRequest *cb = new FRT_RPCRequest(); cb->SetMethodName(req->GetParams()->GetValue(0)._string._str); @@ -25,6 +24,14 @@ RPC::CallBack(FRT_RPCRequest *req) cb->GetErrorMessage()); } cb->SubRef(); + req->Return(); +} + +void +RPC::CallBack(FRT_RPCRequest *req) +{ + req->Detach(); + std::thread(do_callback, req).detach(); } void @@ -32,7 +39,7 @@ RPC::Init(FRT_Supervisor *s) { FRT_ReflectionBuilder rb(s); //------------------------------------------------------------------- - rb.DefineMethod("callBack", "s", "", false, + rb.DefineMethod("callBack", "s", "", FRT_METHOD(RPC::CallBack), this); //------------------------------------------------------------------- } diff --git a/fnet/src/examples/frt/rpc/rpc_server.cpp b/fnet/src/examples/frt/rpc/rpc_server.cpp index 8947663216e..03d618133c9 100644 --- a/fnet/src/examples/frt/rpc/rpc_server.cpp +++ b/fnet/src/examples/frt/rpc/rpc_server.cpp @@ -28,21 +28,21 @@ RPCServer::InitRPC(FRT_Supervisor *s) { FRT_ReflectionBuilder rb(s); //------------------------------------------------------------------- - rb.DefineMethod("concat", "ss", "s", true, + rb.DefineMethod("concat", "ss", "s", FRT_METHOD(RPCServer::RPC_concat), this); rb.MethodDesc("Concatenate two strings"); rb.ParamDesc("string1", "a string"); rb.ParamDesc("string2", "another string"); rb.ReturnDesc("ret", "the concatenation of string1 and string2"); //------------------------------------------------------------------- - rb.DefineMethod("addFloat", "ff", "f", true, + rb.DefineMethod("addFloat", "ff", "f", FRT_METHOD(RPCServer::RPC_addFloat), this); rb.MethodDesc("Add two floats"); rb.ParamDesc("float1", "a float"); rb.ParamDesc("float2", "another float"); rb.ReturnDesc("ret", "float1 + float2"); //------------------------------------------------------------------- - rb.DefineMethod("addDouble", "dd", "d", true, + rb.DefineMethod("addDouble", "dd", "d", FRT_METHOD(RPCServer::RPC_addDouble), this); rb.MethodDesc("Add two doubles"); rb.ParamDesc("double1", "a double"); diff --git a/fnet/src/examples/proxy/proxy.cpp b/fnet/src/examples/proxy/proxy.cpp index 653b445581f..a01a16ead9c 100644 --- a/fnet/src/examples/proxy/proxy.cpp +++ b/fnet/src/examples/proxy/proxy.cpp @@ -227,7 +227,6 @@ Proxy::Main() if (listener != nullptr) listener->SubRef(); - _transport.SetLogStats(true); FNET_SignalShutDown ssd(_transport); _transport.Main(); return 0; diff --git a/fnet/src/tests/frt/method_pt/method_pt.cpp b/fnet/src/tests/frt/method_pt/method_pt.cpp index db5905d6871..5417fddceeb 100644 --- a/fnet/src/tests/frt/method_pt/method_pt.cpp +++ b/fnet/src/tests/frt/method_pt/method_pt.cpp @@ -207,35 +207,35 @@ void initTest() { //------------------------------------------------------------------- - rb.DefineMethod("simpleMethod", "", "", true, + rb.DefineMethod("simpleMethod", "", "", FRT_METHOD(SimpleHandler::RPC_Method), _simpleHandler); //------------------------------------------------------------------- - rb.DefineMethod("mediumMethod1", "", "", true, + rb.DefineMethod("mediumMethod1", "", "", FRT_METHOD(MediumHandler1::RPC_Method), _mediumHandler1); - rb.DefineMethod("mediumMethod2", "", "", true, + rb.DefineMethod("mediumMethod2", "", "", FRT_METHOD(MediumHandler2::RPC_Method), _mediumHandler2); - rb.DefineMethod("mediumMethod3", "", "", true, + rb.DefineMethod("mediumMethod3", "", "", FRT_METHOD(MediumHandler3::RPC_Method), _mediumHandler3); //------------------------------------------------------------------- - rb.DefineMethod("complexMethod1", "", "", true, + rb.DefineMethod("complexMethod1", "", "", FRT_METHOD(ComplexHandler1::RPC_Method), _complexHandler1); - rb.DefineMethod("complexMethod2", "", "", true, + rb.DefineMethod("complexMethod2", "", "", FRT_METHOD(ComplexHandler2::RPC_Method), _complexHandler2); - rb.DefineMethod("complexMethod3", "", "", true, + rb.DefineMethod("complexMethod3", "", "", FRT_METHOD(ComplexHandler3::RPC_Method), _complexHandler3); diff --git a/fnet/src/tests/frt/parallel_rpc/parallel_rpc_test.cpp b/fnet/src/tests/frt/parallel_rpc/parallel_rpc_test.cpp index 4f8b4b82743..31aec84afd5 100644 --- a/fnet/src/tests/frt/parallel_rpc/parallel_rpc_test.cpp +++ b/fnet/src/tests/frt/parallel_rpc/parallel_rpc_test.cpp @@ -3,16 +3,19 @@ #include <vespa/vespalib/util/stringfmt.h> #include <vespa/fnet/frt/frt.h> #include <vespa/vespalib/util/benchmark_timer.h> +#include <vespa/vespalib/net/crypto_engine.h> +#include <vespa/vespalib/net/tls/tls_crypto_engine.h> +#include <vespa/vespalib/test/make_tls_options_for_testing.h> #include <thread> -using vespalib::BenchmarkTimer; +using namespace vespalib; struct Rpc : FRT_Invokable { FastOS_ThreadPool thread_pool; FNET_Transport transport; FRT_Supervisor orb; - Rpc(size_t num_threads) - : thread_pool(128 * 1024), transport(num_threads), orb(&transport, &thread_pool) {} + Rpc(CryptoEngine::SP crypto, size_t num_threads) + : thread_pool(128 * 1024), transport(crypto, num_threads), orb(&transport, &thread_pool) {} void start() { ASSERT_TRUE(transport.Start(&thread_pool)); } @@ -31,13 +34,13 @@ struct Rpc : FRT_Invokable { struct Server : Rpc { uint32_t port; - Server(size_t num_threads) : Rpc(num_threads), port(listen()) { + Server(CryptoEngine::SP crypto, size_t num_threads) : Rpc(crypto, num_threads), port(listen()) { init_rpc(); start(); } void init_rpc() { FRT_ReflectionBuilder rb(&orb); - rb.DefineMethod("inc", "l", "l", true, FRT_METHOD(Server::rpc_inc), this); + rb.DefineMethod("inc", "l", "l", FRT_METHOD(Server::rpc_inc), this); rb.MethodDesc("increment a 64-bit integer"); rb.ParamDesc("in", "an integer (64 bit)"); rb.ReturnDesc("out", "in + 1 (64 bit)"); @@ -51,7 +54,7 @@ struct Server : Rpc { struct Client : Rpc { uint32_t port; - Client(size_t num_threads, const Server &server) : Rpc(num_threads), port(server.port) { + Client(CryptoEngine::SP crypto, size_t num_threads, const Server &server) : Rpc(crypto, num_threads), port(server.port) { start(); } FRT_Target *connect() { return Rpc::connect(port); } @@ -114,10 +117,26 @@ void perform_test(size_t thread_id, Client &client, Result &result) { } } -TEST_MT_FFF("parallel rpc with 1/1 transport threads and 128 user threads", - 128, Server(1), Client(1, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } +CryptoEngine::SP null_crypto = std::make_shared<NullCryptoEngine>(); +CryptoEngine::SP xor_crypto = std::make_shared<XorCryptoEngine>(); +CryptoEngine::SP tls_crypto = std::make_shared<vespalib::TlsCryptoEngine>(vespalib::test::make_tls_options_for_testing()); -TEST_MT_FFF("parallel rpc with 8/8 transport threads and 128 user threads", - 128, Server(8), Client(8, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } +TEST_MT_FFF("parallel rpc with 1/1 transport threads and 128 user threads (no encryption)", + 128, Server(null_crypto, 1), Client(null_crypto, 1, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } + +TEST_MT_FFF("parallel rpc with 1/1 transport threads and 128 user threads (xor encryption)", + 128, Server(xor_crypto, 1), Client(xor_crypto, 1, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } + +TEST_MT_FFF("parallel rpc with 1/1 transport threads and 128 user threads (tls encryption)", + 128, Server(tls_crypto, 1), Client(tls_crypto, 1, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } + +TEST_MT_FFF("parallel rpc with 8/8 transport threads and 128 user threads (no encryption)", + 128, Server(null_crypto, 8), Client(null_crypto, 8, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } + +TEST_MT_FFF("parallel rpc with 8/8 transport threads and 128 user threads (xor encryption)", + 128, Server(xor_crypto, 8), Client(xor_crypto, 8, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } + +TEST_MT_FFF("parallel rpc with 8/8 transport threads and 128 user threads (tls encryption)", + 128, Server(tls_crypto, 8), Client(tls_crypto, 8, f1), Result(num_threads)) { perform_test(thread_id, f2, f3); } TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/fnet/src/tests/frt/rpc/CMakeLists.txt b/fnet/src/tests/frt/rpc/CMakeLists.txt index f935590ee77..2bacd37686a 100644 --- a/fnet/src/tests/frt/rpc/CMakeLists.txt +++ b/fnet/src/tests/frt/rpc/CMakeLists.txt @@ -7,6 +7,7 @@ vespa_add_executable(fnet_invoke_test_app TEST ) vespa_add_test(NAME fnet_invoke_test_app COMMAND fnet_invoke_test_app) vespa_add_test(NAME fnet_invoke_test_app_xor COMMAND fnet_invoke_test_app ENVIRONMENT "CRYPTOENGINE=xor") +vespa_add_test(NAME fnet_invoke_test_app_tls COMMAND fnet_invoke_test_app ENVIRONMENT "CRYPTOENGINE=tls") vespa_add_executable(fnet_detach_return_invoke_test_app TEST SOURCES detach_return_invoke.cpp @@ -22,6 +23,7 @@ vespa_add_executable(fnet_session_test_app TEST ) vespa_add_test(NAME fnet_session_test_app COMMAND fnet_session_test_app) vespa_add_test(NAME fnet_session_test_app_xor COMMAND fnet_session_test_app ENVIRONMENT "CRYPTOENGINE=xor") +vespa_add_test(NAME fnet_session_test_app_tls COMMAND fnet_session_test_app ENVIRONMENT "CRYPTOENGINE=tls") vespa_add_executable(fnet_sharedblob_test_app TEST SOURCES sharedblob.cpp diff --git a/fnet/src/tests/frt/rpc/detach_return_invoke.cpp b/fnet/src/tests/frt/rpc/detach_return_invoke.cpp index 54a891261c2..ab21c62bb68 100644 --- a/fnet/src/tests/frt/rpc/detach_return_invoke.cpp +++ b/fnet/src/tests/frt/rpc/detach_return_invoke.cpp @@ -20,7 +20,7 @@ struct Server : public FRT_Invokable Server(FRT_Supervisor &s, Receptor &r) : orb(s), receptor(r) { FRT_ReflectionBuilder rb(&s); - rb.DefineMethod("hook", "", "", true, + rb.DefineMethod("hook", "", "", FRT_METHOD(Server::rpc_hook), this); } diff --git a/fnet/src/tests/frt/rpc/invoke.cpp b/fnet/src/tests/frt/rpc/invoke.cpp index 787adb227f9..e3bd662214f 100644 --- a/fnet/src/tests/frt/rpc/invoke.cpp +++ b/fnet/src/tests/frt/rpc/invoke.cpp @@ -124,7 +124,7 @@ public: assert(_echo_stash != nullptr && _echo_args != nullptr); FRT_ReflectionBuilder rb(supervisor); - rb.DefineMethod("echo", "*", "*", true, + rb.DefineMethod("echo", "*", "*", FRT_METHOD(EchoTest::RPC_Echo), this); FRT_Values *args = _echo_args; @@ -225,17 +225,15 @@ public: { FRT_ReflectionBuilder rb(supervisor); - rb.DefineMethod("inc", "i", "i", true, + rb.DefineMethod("inc", "i", "i", FRT_METHOD(TestRPC::RPC_Inc), this); - rb.DefineMethod("setValue", "i", "", true, + rb.DefineMethod("setValue", "i", "", FRT_METHOD(TestRPC::RPC_SetValue), this); - rb.DefineMethod("incValue", "", "", true, + rb.DefineMethod("incValue", "", "", FRT_METHOD(TestRPC::RPC_IncValue), this); - rb.DefineMethod("getValue", "", "i", true, + rb.DefineMethod("getValue", "", "i", FRT_METHOD(TestRPC::RPC_GetValue), this); - rb.DefineMethod("testFast", "iiibb", "i", true, - FRT_METHOD(TestRPC::RPC_Test), this); - rb.DefineMethod("testSlow", "iiibb", "i", false, + rb.DefineMethod("testFast", "iiibb", "i", FRT_METHOD(TestRPC::RPC_Test), this); } @@ -364,7 +362,6 @@ const char phase_names[PHASE_ZZZ][32] = enum { TIMING_NULL = 0, TIMING_INSTANT, - TIMING_NON_INSTANT, TIMING_ZZZ }; @@ -372,7 +369,6 @@ const char timing_names[TIMING_ZZZ][32] = { "nullptr", "INSTANT", - "NON-INSTANT" }; enum { @@ -451,17 +447,10 @@ struct State { void PrepareTestMethod() { NewReq(); - bool instant = (_timing == TIMING_INSTANT); - if (_timing != TIMING_INSTANT && - _timing != TIMING_NON_INSTANT) - { + if (_timing != TIMING_INSTANT) { ASSERT_TRUE(false); // consult your dealer... } - if (instant) { - _req->SetMethodName("testFast"); - } else { - _req->SetMethodName("testSlow"); - } + _req->SetMethodName("testFast"); } void SetTestParams(uint32_t value, uint32_t delay, @@ -928,9 +917,9 @@ TEST_F("invoke test", State()) { EXPECT_TRUE(_phase_simple_cnt == 1); EXPECT_TRUE(_phase_void_cnt == 1); EXPECT_TRUE(_phase_speed_cnt == 1); - EXPECT_TRUE(_phase_advanced_cnt == 4); - EXPECT_TRUE(_phase_error_cnt == 4); - EXPECT_TRUE(_phase_abort_cnt == 4); + EXPECT_TRUE(_phase_advanced_cnt == 2); + EXPECT_TRUE(_phase_error_cnt == 2); + EXPECT_TRUE(_phase_abort_cnt == 2); EXPECT_TRUE(_phase_echo_cnt == 1); } diff --git a/fnet/src/tests/frt/rpc/my_crypto_engine.hpp b/fnet/src/tests/frt/rpc/my_crypto_engine.hpp index 6f573e5695a..6cd8d47e917 100644 --- a/fnet/src/tests/frt/rpc/my_crypto_engine.hpp +++ b/fnet/src/tests/frt/rpc/my_crypto_engine.hpp @@ -1,15 +1,21 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include <vespa/vespalib/net/tls/tls_crypto_engine.h> +#include <vespa/vespalib/test/make_tls_options_for_testing.h> + vespalib::CryptoEngine::SP my_crypto_engine() { const char *env_str = getenv("CRYPTOENGINE"); if (!env_str) { - fprintf(stderr, "crypto engine: default\n"); - return vespalib::CryptoEngine::get_default(); + fprintf(stderr, "crypto engine: null\n"); + return std::make_shared<vespalib::NullCryptoEngine>(); } std::string engine(env_str); if (engine == "xor") { fprintf(stderr, "crypto engine: xor\n"); return std::make_shared<vespalib::XorCryptoEngine>(); + } else if (engine == "tls") { + fprintf(stderr, "crypto engine: tls\n"); + return std::make_shared<vespalib::TlsCryptoEngine>(vespalib::test::make_tls_options_for_testing()); } TEST_FATAL(("invalid crypto engine: " + engine).c_str()); abort(); diff --git a/fnet/src/tests/frt/rpc/session.cpp b/fnet/src/tests/frt/rpc/session.cpp index b84db9b4e88..93f14647e21 100644 --- a/fnet/src/tests/frt/rpc/session.cpp +++ b/fnet/src/tests/frt/rpc/session.cpp @@ -77,9 +77,9 @@ struct RPC : public FRT_Invokable void Init(FRT_Supervisor *s) { FRT_ReflectionBuilder rb(s); - rb.DefineMethod("getValue", "", "i", true, + rb.DefineMethod("getValue", "", "i", FRT_METHOD(RPC::GetValue), this); - rb.DefineMethod("setValue", "i", "", true, + rb.DefineMethod("setValue", "i", "", FRT_METHOD(RPC::SetValue), this); s->SetSessionInitHook(FRT_METHOD(RPC::InitSession), this); s->SetSessionFiniHook(FRT_METHOD(RPC::FiniSession), this); diff --git a/fnet/src/tests/frt/rpc/sharedblob.cpp b/fnet/src/tests/frt/rpc/sharedblob.cpp index 10eaad9c013..a48ecbb1da7 100644 --- a/fnet/src/tests/frt/rpc/sharedblob.cpp +++ b/fnet/src/tests/frt/rpc/sharedblob.cpp @@ -176,7 +176,7 @@ TEST("testImplicitShared") { ServerSampler serverSampler(dataSet, req); { FRT_ReflectionBuilder rb(&orb); - rb.DefineMethod("test", "*", "*", true, + rb.DefineMethod("test", "*", "*", FRT_METHOD(ServerSampler::RPC_test), &serverSampler); } orb.Listen(0); diff --git a/fnet/src/tests/info/info.cpp b/fnet/src/tests/info/info.cpp index f22c1402437..f76e66c2af6 100644 --- a/fnet/src/tests/info/info.cpp +++ b/fnet/src/tests/info/info.cpp @@ -24,7 +24,7 @@ struct RPC : public FRT_Invokable { FRT_ReflectionBuilder rb(s); //------------------------------------------------------------------- - rb.DefineMethod("getInfo", "", "sssii", true, + rb.DefineMethod("getInfo", "", "sssii", FRT_METHOD(RPC::GetInfo), this); // FastOS version // FNET version @@ -70,10 +70,10 @@ TEST("info") { TEST("size of important objects") { - EXPECT_EQUAL(176u, sizeof(FNET_IOComponent)); + EXPECT_EQUAL(168u, sizeof(FNET_IOComponent)); EXPECT_EQUAL(32u, sizeof(FNET_Channel)); EXPECT_EQUAL(40u, sizeof(FNET_PacketQueue_NoLock)); - EXPECT_EQUAL(480u, sizeof(FNET_Connection)); + EXPECT_EQUAL(472u, sizeof(FNET_Connection)); EXPECT_EQUAL(48u, sizeof(std::condition_variable)); EXPECT_EQUAL(56u, sizeof(FNET_DataBuffer)); EXPECT_EQUAL(24u, sizeof(FastOS_Time)); diff --git a/fnet/src/vespa/fnet/CMakeLists.txt b/fnet/src/vespa/fnet/CMakeLists.txt index 20badc5c489..4b9d818e5ed 100644 --- a/fnet/src/vespa/fnet/CMakeLists.txt +++ b/fnet/src/vespa/fnet/CMakeLists.txt @@ -17,7 +17,6 @@ vespa_add_library(fnet scheduler.cpp signalshutdown.cpp simplepacketstreamer.cpp - stats.cpp task.cpp transport.cpp transport_thread.cpp diff --git a/fnet/src/vespa/fnet/config.cpp b/fnet/src/vespa/fnet/config.cpp index feed7f2d241..a546d38f78b 100644 --- a/fnet/src/vespa/fnet/config.cpp +++ b/fnet/src/vespa/fnet/config.cpp @@ -3,11 +3,9 @@ #include "config.h" FNET_Config::FNET_Config() - : _minEventTimeOut(0), - _pingInterval(0), - _iocTimeOut(0), + : _iocTimeOut(0), _maxInputBufferSize(0x10000), _maxOutputBufferSize(0x10000), - _tcpNoDelay(true), - _logStats(false) -{ } + _tcpNoDelay(true) +{ +} diff --git a/fnet/src/vespa/fnet/config.h b/fnet/src/vespa/fnet/config.h index e94cf0f6105..3f34c1511b6 100644 --- a/fnet/src/vespa/fnet/config.h +++ b/fnet/src/vespa/fnet/config.h @@ -11,14 +11,10 @@ class FNET_Config { public: - uint32_t _minEventTimeOut; - uint32_t _pingInterval; uint32_t _iocTimeOut; uint32_t _maxInputBufferSize; uint32_t _maxOutputBufferSize; bool _tcpNoDelay; - bool _logStats; FNET_Config(); }; - diff --git a/fnet/src/vespa/fnet/connection.cpp b/fnet/src/vespa/fnet/connection.cpp index b9db8a46a4f..f2864d1dd58 100644 --- a/fnet/src/vespa/fnet/connection.cpp +++ b/fnet/src/vespa/fnet/connection.cpp @@ -110,13 +110,6 @@ FNET_Connection::SetState(State state) } if (oldstate < FNET_CLOSING && state >= FNET_CLOSING) { - if (_flags._writeLock) { - _flags._discarding = true; - while (_flags._writeLock) - _ioc_cond.wait(guard); - _flags._discarding = false; - } - while (!_queue.IsEmpty_NoLock() || !_myQueue.IsEmpty_NoLock()) { _flags._discarding = true; _queue.FlushPackets_NoLock(&_myQueue); @@ -233,14 +226,13 @@ FNET_Connection::handshake() EnableReadEvent(true); EnableWriteEvent(writePendingAfterConnect()); size_t chunk_size = std::max(size_t(FNET_READ_SIZE), _socket->min_read_buffer_size()); - uint32_t ignore_stats = 0; ssize_t res = 0; do { // drain input pipeline _input.EnsureFree(chunk_size); res = _socket->drain(_input.GetFree(), _input.GetFreeLen()); if (res > 0) { _input.FreeToData((uint32_t)res); - broken = !handle_packets(ignore_stats); + broken = !handle_packets(); _input.resetIfEmpty(); } } while ((res > 0) && !broken); } @@ -258,7 +250,7 @@ FNET_Connection::handshake() } bool -FNET_Connection::handle_packets(uint32_t &read_packets) +FNET_Connection::handle_packets() { bool broken = false; for (bool done = false; !done;) { // handle each complete packet in the buffer. @@ -268,7 +260,6 @@ FNET_Connection::handle_packets(uint32_t &read_packets) &broken); } if (_flags._gotheader && (_input.GetDataLen() >= _packetLength)) { - read_packets++; HandlePacket(_packetLength, _packetCode, _packetCHID); _flags._gotheader = false; // reset header flag. } else { @@ -282,26 +273,26 @@ bool FNET_Connection::Read() { size_t chunk_size = std::max(size_t(FNET_READ_SIZE), _socket->min_read_buffer_size()); - uint32_t readData = 0; // total data read - uint32_t readPackets = 0; // total packets read int readCnt = 0; // read count bool broken = false; // is this conn broken ? + int my_errno = 0; // sample and preserve errno ssize_t res; // single read result _input.EnsureFree(chunk_size); res = _socket->read(_input.GetFree(), _input.GetFreeLen()); + my_errno = errno; readCnt++; while (res > 0) { _input.FreeToData((uint32_t)res); - readData += (uint32_t)res; - broken = !handle_packets(readPackets); + broken = !handle_packets(); _input.resetIfEmpty(); if (broken || (_input.GetFreeLen() > 0) || (readCnt >= FNET_READ_REDO)) { goto done_read; } _input.EnsureFree(chunk_size); res = _socket->read(_input.GetFree(), _input.GetFreeLen()); + my_errno = errno; readCnt++; } @@ -310,28 +301,24 @@ done_read: while ((res > 0) && !broken) { // drain input pipeline _input.EnsureFree(chunk_size); res = _socket->drain(_input.GetFree(), _input.GetFreeLen()); + my_errno = errno; readCnt++; if (res > 0) { _input.FreeToData((uint32_t)res); - readData += (uint32_t)res; - broken = !handle_packets(readPackets); + broken = !handle_packets(); _input.resetIfEmpty(); } else if (res == 0) { // fully drained -> EWOULDBLOCK - errno = EWOULDBLOCK; + my_errno = EWOULDBLOCK; res = -1; } } - if (readData > 0) { - UpdateTimeOut(); - CountDataRead(readData); - CountPacketRead(readPackets); - uint32_t maxSize = GetConfig()->_maxInputBufferSize; - if (maxSize > 0 && _input.GetBufSize() > maxSize) - { - if (!_flags._gotheader || _packetLength < maxSize) { - _input.Shrink(maxSize); - } + UpdateTimeOut(); + uint32_t maxSize = GetConfig()->_maxInputBufferSize; + if (maxSize > 0 && _input.GetBufSize() > maxSize) + { + if (!_flags._gotheader || _packetLength < maxSize) { + _input.Shrink(maxSize); } } @@ -339,9 +326,9 @@ done_read: if (res == 0) { broken = true; // handle EOF } else { // res < 0 - broken = ((errno != EWOULDBLOCK) && (errno != EAGAIN)); - if (broken && (errno != ECONNRESET)) { - LOG(debug, "Connection(%s): read error: %d", GetSpec(), errno); + broken = ((my_errno != EWOULDBLOCK) && (my_errno != EAGAIN)); + if (broken && (my_errno != ECONNRESET)) { + LOG(debug, "Connection(%s): read error: %d", GetSpec(), my_errno); } } } @@ -354,10 +341,9 @@ bool FNET_Connection::Write() { uint32_t my_write_work = 0; - uint32_t writtenData = 0; // total data written - uint32_t writtenPackets = 0; // total packets written int writeCnt = 0; // write count bool broken = false; // is this conn broken ? + int my_errno = 0; // sample and preserve errno ssize_t res; // single write result FNET_Packet *packet; @@ -374,7 +360,6 @@ FNET_Connection::Write() packet = _myQueue.DequeuePacket_NoLock(&context); if (packet->IsRegularPacket()) { // ignore non-regular packets _streamer->Encode(packet, context._value.INT, &_output); - writtenPackets++; } packet->Free(); } @@ -387,10 +372,10 @@ FNET_Connection::Write() // write data res = _socket->write(_output.GetData(), _output.GetDataLen()); + my_errno = errno; writeCnt++; if (res > 0) { _output.DataToDead((uint32_t)res); - writtenData += (uint32_t)res; _output.resetIfEmpty(); } } while (res > 0 && @@ -404,26 +389,26 @@ FNET_Connection::Write() if (res >= 0) { // flush output pipeline res = _socket->flush(); + my_errno = errno; while (res > 0) { res = _socket->flush(); + my_errno = errno; } } - if (writtenData > 0) { - uint32_t maxSize = GetConfig()->_maxOutputBufferSize; - if (maxSize > 0 && _output.GetBufSize() > maxSize) { - _output.Shrink(maxSize); - } + uint32_t maxSize = GetConfig()->_maxOutputBufferSize; + if (maxSize > 0 && _output.GetBufSize() > maxSize) { + _output.Shrink(maxSize); } if (res < 0) { - if ((errno == EWOULDBLOCK) || (errno == EAGAIN)) { + if ((my_errno == EWOULDBLOCK) || (my_errno == EAGAIN)) { ++my_write_work; // incomplete write/flush } else { broken = true; } - if (broken && (errno != ECONNRESET)) { - LOG(debug, "Connection(%s): write error: %d", GetSpec(), errno); + if (broken && (my_errno != ECONNRESET)) { + LOG(debug, "Connection(%s): write error: %d", GetSpec(), my_errno); } } @@ -431,17 +416,9 @@ FNET_Connection::Write() _writeWork = _queue.GetPacketCnt_NoLock() + _myQueue.GetPacketCnt_NoLock() + my_write_work; - _flags._writeLock = false; - if (_flags._discarding) { - _ioc_cond.notify_all(); - } bool writePending = (_writeWork > 0); guard.unlock(); - if (writtenData > 0) { - CountDataWrite(writtenData); - CountPacketWrite(writtenPackets); - } if (!writePending) EnableWriteEvent(false); @@ -528,7 +505,6 @@ FNET_Connection::~FNET_Connection() delete _adminChannel; } assert(_cleanup == nullptr); - assert(!_flags._writeLock); } @@ -682,9 +658,7 @@ FNET_Connection::PostPacket(FNET_Packet *packet, uint32_t chid) writeWork = _writeWork; _writeWork++; _queue.QueuePacket_NoLock(packet, FNET_Context(chid)); - if (writeWork == 0 && !_flags._writeLock && - _state == FNET_CONNECTED) - { + if ((writeWork == 0) && (_state == FNET_CONNECTED)) { AddRef_NoLock(); guard.unlock(); Owner()->EnableWrite(this, /* needRef = */ false); @@ -693,14 +667,6 @@ FNET_Connection::PostPacket(FNET_Packet *packet, uint32_t chid) } -uint32_t -FNET_Connection::GetQueueLen() -{ - std::lock_guard<std::mutex> guard(_ioc_lock); - return _queue.GetPacketCnt_NoLock() + _myQueue.GetPacketCnt_NoLock(); -} - - void FNET_Connection::Sync() { @@ -774,12 +740,6 @@ FNET_Connection::HandleWriteEvent() case FNET_CONNECTED: { std::unique_lock<std::mutex> guard(_ioc_lock); - if (_flags._writeLock) { - guard.unlock(); - EnableWriteEvent(false); - return true; - } - _flags._writeLock = true; _queue.FlushPackets_NoLock(&_myQueue); } broken = !Write(); diff --git a/fnet/src/vespa/fnet/connection.h b/fnet/src/vespa/fnet/connection.h index 8e5e6280fab..8e275d68b18 100644 --- a/fnet/src/vespa/fnet/connection.h +++ b/fnet/src/vespa/fnet/connection.h @@ -68,13 +68,11 @@ private: struct Flags { Flags() : _gotheader(false), - _writeLock(false), _inCallback(false), _callbackWait(false), _discarding(false) { } bool _gotheader; - bool _writeLock; bool _inCallback; bool _callbackWait; bool _discarding; @@ -212,9 +210,8 @@ private: * for each one. * * @return false if socket is broken. - * @param read_packets count read packets here **/ - bool handle_packets(uint32_t &read_packets); + bool handle_packets(); /** * Read incoming data from socket. @@ -450,19 +447,6 @@ public: /** - * Obtain the number of packets located in the output queue for this - * connection. Note that this number is volatile and should only be - * used as an estimate. Also note that since a queue latching - * strategy is used, this method requires a mutex lock/unlock and is - * therefore not as cheap as may be expected. - * - * @return number of packets currently located in the output queue - * for this connection. - **/ - uint32_t GetQueueLen(); - - - /** * Sync with this connection. When this method is invoked it will * block until all packets currently posted on this connection is * encoded into the output buffer. Also, the amount of data in the diff --git a/fnet/src/vespa/fnet/fnet.h b/fnet/src/vespa/fnet/fnet.h index 5a3a8b28942..c7570e025ec 100644 --- a/fnet/src/vespa/fnet/fnet.h +++ b/fnet/src/vespa/fnet/fnet.h @@ -32,8 +32,6 @@ class FNET_Packet; class FNET_PacketQueue; class FNET_Scheduler; class FNET_SimplePacketStreamer; -class FNET_StatCounters; -class FNET_Stats; class FNET_Task; class FNET_Transport; class FNET_TransportThread; @@ -52,7 +50,6 @@ class FNET_TransportThread; #include "task.h" #include "scheduler.h" #include "config.h" -#include "stats.h" #include "databuffer.h" #include "packet.h" #include "dummypacket.h" diff --git a/fnet/src/vespa/fnet/frt/invoker.cpp b/fnet/src/vespa/fnet/frt/invoker.cpp index f2dc331c707..b174c3a710e 100644 --- a/fnet/src/vespa/fnet/frt/invoker.cpp +++ b/fnet/src/vespa/fnet/frt/invoker.cpp @@ -64,18 +64,14 @@ FRT_RPCInvoker::FRT_RPCInvoker(FRT_Supervisor *supervisor, req->SetReturnHandler(this); } -bool FRT_RPCInvoker::IsInstant() { - return _method->IsInstant(); -} - -bool FRT_RPCInvoker::Invoke(bool freeChannel) +bool FRT_RPCInvoker::Invoke() { bool detached = false; _req->SetDetachedPT(&detached); (_method->GetHandler()->*_method->GetMethod())(_req); if (detached) return false; - HandleDone(freeChannel); + HandleDone(false); return true; } @@ -120,13 +116,6 @@ FRT_RPCInvoker::GetConnection() return _req->GetContext()._value.CHANNEL->GetConnection(); } - -void -FRT_RPCInvoker::Run(FastOS_ThreadInterface *, void *) -{ - Invoke(true); -} - //----------------------------------------------------------------------------- void FRT_HookInvoker::Invoke() diff --git a/fnet/src/vespa/fnet/frt/invoker.h b/fnet/src/vespa/fnet/frt/invoker.h index 15d74017200..64adf66688e 100644 --- a/fnet/src/vespa/fnet/frt/invoker.h +++ b/fnet/src/vespa/fnet/frt/invoker.h @@ -59,8 +59,7 @@ public: //----------------------------------------------------------------------------- -class FRT_RPCInvoker : public FastOS_Runnable, - public FRT_IReturnHandler +class FRT_RPCInvoker : public FRT_IReturnHandler { private: FRT_RPCRequest *_req; @@ -76,15 +75,13 @@ public: bool noReply); void ForceMethod(FRT_Method *method) { _method = method; } - bool IsInstant(); FRT_RPCRequest *GetRequest() { return _req; } void HandleDone(bool freeChannel); - bool Invoke(bool freeChannel); + bool Invoke(); void HandleReturn() override; FNET_Connection *GetConnection() override; - void Run(FastOS_ThreadInterface *, void *) override; }; //----------------------------------------------------------------------------- diff --git a/fnet/src/vespa/fnet/frt/reflection.cpp b/fnet/src/vespa/fnet/frt/reflection.cpp index 4285c512ebf..305294f4a3c 100644 --- a/fnet/src/vespa/fnet/frt/reflection.cpp +++ b/fnet/src/vespa/fnet/frt/reflection.cpp @@ -6,13 +6,12 @@ #include "supervisor.h" FRT_Method::FRT_Method(const char * name, const char * paramSpec, const char * returnSpec, - bool instant, FRT_METHOD_PT method, FRT_Invokable * handler) + FRT_METHOD_PT method, FRT_Invokable * handler) : _hashNext(nullptr), _listNext(nullptr), _name(strdup(name)), _paramSpec(strdup(paramSpec)), _returnSpec(strdup(returnSpec)), - _instant(instant), _method(method), _handler(handler), _docLen(0), @@ -171,7 +170,6 @@ void FRT_ReflectionBuilder::DefineMethod(const char *name, const char *paramSpec, const char *returnSpec, - bool instant, FRT_METHOD_PT method, FRT_Invokable *handler) { @@ -182,7 +180,6 @@ FRT_ReflectionBuilder::DefineMethod(const char *name, _method = new FRT_Method(name, paramSpec, returnSpec, - instant, method, handler); _lookup->AddMethod(_method); diff --git a/fnet/src/vespa/fnet/frt/reflection.h b/fnet/src/vespa/fnet/frt/reflection.h index 466e58413e9..5189cf81d0a 100644 --- a/fnet/src/vespa/fnet/frt/reflection.h +++ b/fnet/src/vespa/fnet/frt/reflection.h @@ -19,7 +19,6 @@ private: char *_name; // method name char *_paramSpec; // method parameter spec char *_returnSpec; // method return spec - bool _instant; // method is instant ? FRT_METHOD_PT _method; // method pointer FRT_Invokable *_handler; // method handler uint32_t _docLen; // method documentation length @@ -32,7 +31,6 @@ public: FRT_Method(const char *name, const char *paramSpec, const char *returnSpec, - bool instant, FRT_METHOD_PT method, FRT_Invokable *handler); @@ -42,7 +40,6 @@ public: const char *GetName() { return _name; } const char *GetParamSpec() { return _paramSpec; } const char *GetReturnSpec() { return _returnSpec; } - bool IsInstant() { return _instant; } FRT_METHOD_PT GetMethod() { return _method; } FRT_Invokable *GetHandler() { return _handler; } void SetDocumentation(FRT_Values *values); @@ -121,7 +118,6 @@ public: void DefineMethod(const char *name, const char *paramSpec, const char *returnSpec, - bool instant, FRT_METHOD_PT method, FRT_Invokable *handler); void MethodDesc(const char *desc); diff --git a/fnet/src/vespa/fnet/frt/rpcrequest.h b/fnet/src/vespa/fnet/frt/rpcrequest.h index a10653ce2f6..cc871e7ac0c 100644 --- a/fnet/src/vespa/fnet/frt/rpcrequest.h +++ b/fnet/src/vespa/fnet/frt/rpcrequest.h @@ -133,7 +133,7 @@ public: FNET_Packet *CreateReplyPacket(); void SetDetachedPT(bool *detachedPT) { _detachedPT = detachedPT; } - void Detach() { *_detachedPT = true; } + FRT_RPCRequest *Detach() { *_detachedPT = true; return this; } void SetAbortHandler(FRT_IAbortHandler *handler) { _abortHandler = handler; } void SetReturnHandler(FRT_IReturnHandler *handler) { _returnHandler = handler; } diff --git a/fnet/src/vespa/fnet/frt/supervisor.cpp b/fnet/src/vespa/fnet/frt/supervisor.cpp index 927e2e84b94..e509223c005 100644 --- a/fnet/src/vespa/fnet/frt/supervisor.cpp +++ b/fnet/src/vespa/fnet/frt/supervisor.cpp @@ -91,22 +91,6 @@ FRT_Supervisor::GetListenPort() const } -bool -FRT_Supervisor::RunInvocation(FRT_RPCInvoker *invoker) -{ - // XXX: implement queue with max length + max # threads - - if (_threadPool == nullptr || - _threadPool->NewThread(invoker) == nullptr) - { - invoker->GetRequest()->SetError(FRTE_RPC_OVERLOAD, - "Could not start thread"); - return false; - } - return true; -} - - FRT_Target * FRT_Supervisor::GetTarget(const char *spec) { @@ -179,7 +163,7 @@ FRT_Supervisor::SetMethodMismatchHook(FRT_METHOD_PT method, { delete _methodMismatchHook; _methodMismatchHook = new FRT_Method("frt.hook.methodMismatch", "*", "*", - true, method, handler); + method, handler); assert(_methodMismatchHook != nullptr); } @@ -284,25 +268,17 @@ FRT_Supervisor::HandlePacket(FNET_Packet *packet, FNET_Context context) && _methodMismatchHook != nullptr) { invoker->ForceMethod(_methodMismatchHook); - return (invoker->Invoke(false)) ? + return (invoker->Invoke()) ? FNET_FREE_CHANNEL : FNET_CLOSE_CHANNEL; } invoker->HandleDone(false); return FNET_FREE_CHANNEL; - } else if (invoker->IsInstant()) { - - return (invoker->Invoke(false)) ? - FNET_FREE_CHANNEL : FNET_CLOSE_CHANNEL; - } else { - if (!RunInvocation(invoker)) { - invoker->HandleDone(false); - return FNET_FREE_CHANNEL; - } - return FNET_CLOSE_CHANNEL; + return (invoker->Invoke()) ? + FNET_FREE_CHANNEL : FNET_CLOSE_CHANNEL; } } @@ -349,17 +325,17 @@ FRT_Supervisor::RPCHooks::InitRPC(FRT_Supervisor *supervisor) { FRT_ReflectionBuilder rb(supervisor); //--------------------------------------------------------------------------- - rb.DefineMethod("frt.rpc.ping", "", "", true, + rb.DefineMethod("frt.rpc.ping", "", "", FRT_METHOD(FRT_Supervisor::RPCHooks::RPC_Ping), this); rb.MethodDesc("Method that may be used to check if the server is online"); //--------------------------------------------------------------------------- - rb.DefineMethod("frt.rpc.echo", "*", "*", true, + rb.DefineMethod("frt.rpc.echo", "*", "*", FRT_METHOD(FRT_Supervisor::RPCHooks::RPC_Echo), this); rb.MethodDesc("Echo the parameters as return values"); rb.ParamDesc("params", "Any set of parameters"); rb.ReturnDesc("return", "The parameter values"); //--------------------------------------------------------------------------- - rb.DefineMethod("frt.rpc.getMethodList", "", "SSS", true, + rb.DefineMethod("frt.rpc.getMethodList", "", "SSS", FRT_METHOD(FRT_Supervisor::RPCHooks::RPC_GetMethodList), this); rb.MethodDesc("Obtain a list of all available methods"); @@ -367,7 +343,7 @@ FRT_Supervisor::RPCHooks::InitRPC(FRT_Supervisor *supervisor) rb.ReturnDesc("params", "Method parameter types"); rb.ReturnDesc("return", "Method return types"); //--------------------------------------------------------------------------- - rb.DefineMethod("frt.rpc.getMethodInfo", "s", "sssSSSS", true, + rb.DefineMethod("frt.rpc.getMethodInfo", "s", "sssSSSS", FRT_METHOD(FRT_Supervisor::RPCHooks::RPC_GetMethodInfo), this); rb.MethodDesc("Obtain detailed information about a single method"); @@ -448,7 +424,7 @@ FRT_Supervisor::ConnHooks::SetSessionInitHook(FRT_METHOD_PT method, { delete _sessionInitHook; _sessionInitHook = new FRT_Method("frt.hook.sessionInit", "", "", - true, method, handler); + method, handler); assert(_sessionInitHook != nullptr); } @@ -459,7 +435,7 @@ FRT_Supervisor::ConnHooks::SetSessionDownHook(FRT_METHOD_PT method, { delete _sessionDownHook; _sessionDownHook = new FRT_Method("frt.hook.sessionDown", "", "", - true, method, handler); + method, handler); assert(_sessionDownHook != nullptr); } @@ -470,7 +446,7 @@ FRT_Supervisor::ConnHooks::SetSessionFiniHook(FRT_METHOD_PT method, { delete _sessionFiniHook; _sessionFiniHook = new FRT_Method("frt.hook.sessionFini", "", "", - true, method, handler); + method, handler); assert(_sessionFiniHook != nullptr); } diff --git a/fnet/src/vespa/fnet/frt/supervisor.h b/fnet/src/vespa/fnet/frt/supervisor.h index 051c1caceeb..dc7fb496239 100644 --- a/fnet/src/vespa/fnet/frt/supervisor.h +++ b/fnet/src/vespa/fnet/frt/supervisor.h @@ -99,8 +99,6 @@ public: bool Listen(int port); uint32_t GetListenPort() const; - bool RunInvocation(FRT_RPCInvoker *invoker); - FRT_Target *GetTarget(const char *spec); FRT_Target *Get2WayTarget(const char *spec, FNET_Context connContext = FNET_Context()); diff --git a/fnet/src/vespa/fnet/iocomponent.cpp b/fnet/src/vespa/fnet/iocomponent.cpp index 148dabf5c60..d4244cbf204 100644 --- a/fnet/src/vespa/fnet/iocomponent.cpp +++ b/fnet/src/vespa/fnet/iocomponent.cpp @@ -12,7 +12,6 @@ FNET_IOComponent::FNET_IOComponent(FNET_TransportThread *owner, : _ioc_next(nullptr), _ioc_prev(nullptr), _ioc_owner(owner), - _ioc_counters(_ioc_owner->GetStatCounters()), _ioc_socket_fd(socket_fd), _ioc_selector(nullptr), _ioc_spec(nullptr), diff --git a/fnet/src/vespa/fnet/iocomponent.h b/fnet/src/vespa/fnet/iocomponent.h index 16ecce2e345..901c3d1a5d0 100644 --- a/fnet/src/vespa/fnet/iocomponent.h +++ b/fnet/src/vespa/fnet/iocomponent.h @@ -2,14 +2,12 @@ #pragma once -#include "stats.h" #include <vespa/fastos/timestamp.h> #include <vespa/vespalib/net/selector.h> #include <mutex> #include <condition_variable> class FNET_TransportThread; -class FNET_StatCounters; class FNET_Config; /** @@ -45,7 +43,6 @@ protected: FNET_IOComponent *_ioc_next; // next in list FNET_IOComponent *_ioc_prev; // prev in list FNET_TransportThread *_ioc_owner; // owner(TransportThread) ref. - FNET_StatCounters *_ioc_counters; // stat counters int _ioc_socket_fd; // source of events. Selector *_ioc_selector; // attached event selector char *_ioc_spec; // connect/listen spec @@ -154,47 +151,6 @@ public: **/ void UpdateTimeOut(); - - /** - * Count packet read(s). This is a proxy method updating the stat - * counters associated with the owning transport object. - * - * @param cnt the number of packets read (default is 1). - **/ - void CountPacketRead(uint32_t cnt = 1) - { _ioc_counters->CountPacketRead(cnt); } - - - /** - * Count packet write(s). This is a proxy method updating the stat - * counters associated with the owning transport object. - * - * @param cnt the number of packets written (default is 1). - **/ - void CountPacketWrite(uint32_t cnt = 1) - { _ioc_counters->CountPacketWrite(cnt); } - - - /** - * Count read data. This is a proxy method updating the stat - * counters associated with the owning transport object. - * - * @param bytes the number of bytes read. - **/ - void CountDataRead(uint32_t bytes) - { _ioc_counters->CountDataRead(bytes); } - - - /** - * Count written data. This is a proxy method updating the stat - * counters associated with the owning transport object. - * - * @param bytes the number of bytes written. - **/ - void CountDataWrite(uint32_t bytes) - { _ioc_counters->CountDataWrite(bytes); } - - /** * Attach an event selector to this component. Before deleting an * IOC, one must first call detach_selector to detach the diff --git a/fnet/src/vespa/fnet/stats.cpp b/fnet/src/vespa/fnet/stats.cpp deleted file mode 100644 index f156fe4afe7..00000000000 --- a/fnet/src/vespa/fnet/stats.cpp +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#include "stats.h" - -#include <vespa/log/log.h> -LOG_SETUP(".fnet"); - -FNET_StatCounters::FNET_StatCounters() - : _eventLoopCnt(0), - _eventCnt(0), - _ioEventCnt(0), - _packetReadCnt(0), - _packetWriteCnt(0), - _dataReadCnt(0), - _dataWriteCnt(0) -{ -} - - -FNET_StatCounters::~FNET_StatCounters() -{ -} - - -void -FNET_StatCounters::Clear() -{ - _eventLoopCnt = 0; - _eventCnt = 0; - _ioEventCnt = 0; - _packetReadCnt = 0; - _packetWriteCnt = 0; - _dataReadCnt = 0; - _dataWriteCnt = 0; -} - -//----------------------------------------------- - -FNET_Stats::FNET_Stats() - : _eventLoopRate(0), - _eventRate(0), - _ioEventRate(0), - _packetReadRate(0), - _packetWriteRate(0), - _dataReadRate(0), - _dataWriteRate(0) -{ -} - - -FNET_Stats::~FNET_Stats() -{ -} - - -void -FNET_Stats::Update(FNET_StatCounters *count, double secs) -{ - _eventLoopRate = (float)(FNET_STATS_OLD_FACTOR * _eventLoopRate - + (FNET_STATS_NEW_FACTOR - * ((double)count->_eventLoopCnt / secs))); - _eventRate = (float)(FNET_STATS_OLD_FACTOR * _eventRate - + (FNET_STATS_NEW_FACTOR - * ((double)count->_eventCnt / secs))); - _ioEventRate = (float)(FNET_STATS_OLD_FACTOR * _ioEventRate - + (FNET_STATS_NEW_FACTOR - * ((double)count->_ioEventCnt / secs))); - - _packetReadRate = (float)(FNET_STATS_OLD_FACTOR * _packetReadRate - + (FNET_STATS_NEW_FACTOR - * ((double)count->_packetReadCnt / secs))); - _packetWriteRate = (float)(FNET_STATS_OLD_FACTOR * _packetWriteRate - + (FNET_STATS_NEW_FACTOR - * ((double)count->_packetWriteCnt / secs))); - - _dataReadRate = (float)(FNET_STATS_OLD_FACTOR * _dataReadRate - + (FNET_STATS_NEW_FACTOR - * ((double)count->_dataReadCnt / (1000.0 * secs)))); - _dataWriteRate = (float)(FNET_STATS_OLD_FACTOR * _dataWriteRate - + (FNET_STATS_NEW_FACTOR - * ((double)count->_dataWriteCnt / (1000.0 * secs)))); -} - - -void -FNET_Stats::Log() -{ - LOG(info, "events[/s][loop/int/io][%.1f/%.1f/%.1f] " - "packets[/s][r/w][%.1f/%.1f] " - "data[kB/s][r/w][%.2f/%.2f]", - _eventLoopRate, - _eventRate, - _ioEventRate, - _packetReadRate, - _packetWriteRate, - _dataReadRate, - _dataWriteRate); -} diff --git a/fnet/src/vespa/fnet/stats.h b/fnet/src/vespa/fnet/stats.h deleted file mode 100644 index 76651393165..00000000000 --- a/fnet/src/vespa/fnet/stats.h +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#pragma once - -#include <cstdint> - -/** - * This class is used internally by @ref FNET_Transport objects to - * aggregate FNET statistics. The actual statistics are located in the - * @ref FNET_Stats class. - **/ -class FNET_StatCounters -{ -public: - uint32_t _eventLoopCnt; // # event loop iterations - uint32_t _eventCnt; // # internal events - uint32_t _ioEventCnt; // # IO events - uint32_t _packetReadCnt; // # packets read - uint32_t _packetWriteCnt; // # packets written - uint32_t _dataReadCnt; // # bytes read - uint32_t _dataWriteCnt; // # bytes written - - FNET_StatCounters(); - ~FNET_StatCounters(); - - void Clear(); - void CountEventLoop(uint32_t cnt) { _eventLoopCnt += cnt; } - void CountEvent(uint32_t cnt) { _eventCnt += cnt; } - void CountIOEvent(uint32_t cnt) { _ioEventCnt += cnt; } - void CountPacketRead(uint32_t cnt) { _packetReadCnt += cnt; } - void CountPacketWrite(uint32_t cnt) { _packetWriteCnt += cnt; } - void CountDataRead(uint32_t bytes) { _dataReadCnt += bytes; } - void CountDataWrite(uint32_t bytes) { _dataWriteCnt += bytes; } -}; - -//----------------------------------------------- - -#define FNET_STATS_OLD_FACTOR 0.5 -#define FNET_STATS_NEW_FACTOR 0.5 - -/** - * This class contains various FNET statistics. The statistics for a - * @ref FNET_Transport object may be obtained by invoking the GetStats - * method on it. - **/ -class FNET_Stats -{ -public: - /** - * Event loop iterations per second. - **/ - float _eventLoopRate; // loop iterations/s - - /** - * Internal events handled per second. - **/ - float _eventRate; // internal-events/s - - /** - * IO events handled per second. - **/ - float _ioEventRate; // IO-events/s - - /** - * Packets read per second. - **/ - float _packetReadRate; // packets/s - - /** - * Packets written per second. - **/ - float _packetWriteRate; // packets/s - - /** - * Data read per second (in kB). - **/ - float _dataReadRate; // kB/s - - /** - * Data written per second (in kB). - **/ - float _dataWriteRate; // kB/s - - FNET_Stats(); - ~FNET_Stats(); - - /** - * Update statistics. The new statistics are calculated based on - * both the current values and the input count structure indicating - * what has happened since the last statistics update. - * - * @param count what has happened since last statistics update. - * @param secs number of seconds since last statistics update. - **/ - void Update(FNET_StatCounters *count, double secs); - - /** - * Invoking this method will generate a log message of type - * FNET_INFO showing the values held by this object. - **/ - void Log(); -}; - diff --git a/fnet/src/vespa/fnet/transport.cpp b/fnet/src/vespa/fnet/transport.cpp index 55c847682c4..dfeb8d03436 100644 --- a/fnet/src/vespa/fnet/transport.cpp +++ b/fnet/src/vespa/fnet/transport.cpp @@ -120,14 +120,6 @@ FNET_Transport::SetTCPNoDelay(bool noDelay) } void -FNET_Transport::SetLogStats(bool logStats) -{ - for (const auto &thread: _threads) { - thread->SetLogStats(logStats); - } -} - -void FNET_Transport::sync() { for (const auto &thread: _threads) { diff --git a/fnet/src/vespa/fnet/transport.h b/fnet/src/vespa/fnet/transport.h index dbf914798fd..15e69bd66a6 100644 --- a/fnet/src/vespa/fnet/transport.h +++ b/fnet/src/vespa/fnet/transport.h @@ -195,14 +195,6 @@ public: void SetTCPNoDelay(bool noDelay); /** - * Enable or disable logging of FNET statistics. This feature is - * disabled by default. - * - * @param logStats true if stats should be logged. - **/ - void SetLogStats(bool logStats); - - /** * Synchronize with all transport threads. This method will block * until all events posted before this method was invoked has been * processed. If a transport thread has been shut down (or is in diff --git a/fnet/src/vespa/fnet/transport_thread.cpp b/fnet/src/vespa/fnet/transport_thread.cpp index 34ab9091072..b0388bdc140 100644 --- a/fnet/src/vespa/fnet/transport_thread.cpp +++ b/fnet/src/vespa/fnet/transport_thread.cpp @@ -31,15 +31,6 @@ struct Sync : public FNET_IExecutable } // namespace<unnamed> -#ifndef IAM_DOXYGEN -void -FNET_TransportThread::StatsTask::PerformTask() -{ - _transport->UpdateStats(); - Schedule(5.0); -} -#endif - void FNET_TransportThread::AddComponent(FNET_IOComponent *comp) { @@ -160,22 +151,6 @@ FNET_TransportThread::DiscardEvent(FNET_ControlPacket *cpacket, } -void -FNET_TransportThread::UpdateStats() -{ - _now.SetNow(); // trade some overhead for better stats - double ms = _now.MilliSecs() - _statTime.MilliSecs(); - _statTime = _now; - { - std::lock_guard<std::mutex> guard(_lock); - _stats.Update(&_counters, ms / 1000.0); - } - _counters.Clear(); - - if (_config._logStats) - _stats.Log(); -} - extern "C" { static void pipehandler(int) @@ -203,10 +178,6 @@ FNET_TransportThread::FNET_TransportThread(FNET_Transport &owner_in) _startTime(), _now(), _scheduler(&_now), - _counters(), - _stats(), - _statsTask(&_scheduler, this), - _statTime(), _config(), _componentsHead(nullptr), _timeOutHead(nullptr), @@ -424,8 +395,6 @@ FNET_TransportThread::InitEventLoop() } _now.SetNow(); _startTime = _now; - _statTime = _now; - _statsTask.Schedule(5.0); return true; } @@ -435,7 +404,7 @@ FNET_TransportThread::handle_wakeup() { { std::lock_guard<std::mutex> guard(_lock); - CountEvent(_queue.FlushPackets_NoLock(&_myQueue)); + _queue.FlushPackets_NoLock(&_myQueue); } FNET_Context context; @@ -534,7 +503,6 @@ FNET_TransportThread::EventLoopIteration() // obtain I/O events _selector.poll(msTimeout); - CountEventLoop(); // sample current time (performed once per event loop iteration) _now.SetNow(); @@ -548,7 +516,6 @@ FNET_TransportThread::EventLoopIteration() #endif // handle wakeup and io-events - CountIOEvent(_selector.num_events()); _selector.dispatch(*this); // handle IOC time-outs @@ -579,9 +546,6 @@ FNET_TransportThread::EventLoopIteration() if (_finished) return false; - // unschedule statistics task - _statsTask.Kill(); - // flush event queue { std::lock_guard<std::mutex> guard(_lock); diff --git a/fnet/src/vespa/fnet/transport_thread.h b/fnet/src/vespa/fnet/transport_thread.h index 3e5fe49e73a..1b8d1fa4eeb 100644 --- a/fnet/src/vespa/fnet/transport_thread.h +++ b/fnet/src/vespa/fnet/transport_thread.h @@ -6,7 +6,6 @@ #include "config.h" #include "task.h" #include "packetqueue.h" -#include "stats.h" #include <vespa/fastos/thread.h> #include <vespa/fastos/time.h> #include <vespa/vespalib/net/socket_handle.h> @@ -31,31 +30,11 @@ class FNET_TransportThread : public FastOS_Runnable public: using Selector = vespalib::Selector<FNET_IOComponent>; -#ifndef IAM_DOXYGEN - class StatsTask : public FNET_Task - { - private: - FNET_TransportThread *_transport; - StatsTask(const StatsTask &); - StatsTask &operator=(const StatsTask &); - public: - StatsTask(FNET_Scheduler *scheduler, - FNET_TransportThread *transport) : FNET_Task(scheduler), - _transport(transport) {} - void PerformTask() override; - }; - friend class FNET_TransportThread::StatsTask; -#endif // DOXYGEN - private: FNET_Transport &_owner; // owning transport layer FastOS_Time _startTime; // when event loop started FastOS_Time _now; // current time sampler FNET_Scheduler _scheduler; // transport thread scheduler - FNET_StatCounters _counters; // stat counters - FNET_Stats _stats; // current stats - StatsTask _statsTask; // stats task - FastOS_Time _statTime; // last stat update FNET_Config _config; // FNET configuration [static] FNET_IOComponent *_componentsHead; // I/O component list head FNET_IOComponent *_timeOutHead; // first IOC in list to time out @@ -156,49 +135,6 @@ private: /** - * Update internal FNET statistics. This method is called regularly - * by the statistics update task. - **/ - void UpdateStats(); - - - /** - * Obtain a reference to the stat counters used by this transport - * object. - * - * @return stat counters for this transport object. - **/ - FNET_StatCounters *GetStatCounters() { return &_counters; } - - - /** - * Count event loop iteration(s). - * - * @param cnt event loop iterations (default is 1). - **/ - void CountEventLoop(uint32_t cnt = 1) - { _counters.CountEventLoop(cnt); } - - - /** - * Count internal event(s). - * - * @param cnt number of internal events. - **/ - void CountEvent(uint32_t cnt) - { _counters.CountEvent(cnt); } - - - /** - * Count IO events. - * - * @param cnt number of IO events. - **/ - void CountIOEvent(uint32_t cnt) - { _counters.CountIOEvent(cnt); } - - - /** * Obtain a reference to the object holding the configuration for * this transport object. * @@ -355,15 +291,6 @@ public: /** - * Enable or disable logging of FNET statistics. This feature is - * disabled by default. - * - * @param logStats true if stats should be logged. - **/ - void SetLogStats(bool logStats) { _config._logStats = logStats; } - - - /** * Add an I/O component to the working set of this transport * object. Note that the actual work is performed by the transport * thread. This method simply posts an event on the transport thread diff --git a/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/athenz/AthenzPrincipalFilterTest.java b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/athenz/AthenzPrincipalFilterTest.java index be5ab9c1d77..fdab450b435 100644 --- a/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/athenz/AthenzPrincipalFilterTest.java +++ b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/athenz/AthenzPrincipalFilterTest.java @@ -10,9 +10,9 @@ import com.yahoo.vespa.athenz.api.AthenzIdentity; import com.yahoo.vespa.athenz.api.AthenzPrincipal; import com.yahoo.vespa.athenz.api.AthenzUser; import com.yahoo.vespa.athenz.api.NToken; -import com.yahoo.vespa.athenz.tls.KeyAlgorithm; -import com.yahoo.vespa.athenz.tls.KeyUtils; -import com.yahoo.vespa.athenz.tls.X509CertificateBuilder; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.X509CertificateBuilder; import com.yahoo.vespa.athenz.utils.ntoken.NTokenValidator; import org.junit.Before; import org.junit.Test; @@ -22,6 +22,7 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.io.UncheckedIOException; +import java.math.BigInteger; import java.security.KeyPair; import java.security.cert.X509Certificate; import java.time.Duration; @@ -30,7 +31,8 @@ import java.util.Objects; import java.util.Set; import static com.yahoo.jdisc.Response.Status.UNAUTHORIZED; -import static com.yahoo.vespa.athenz.tls.SignatureAlgorithm.SHA256_WITH_RSA; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_ECDSA; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_RSA; import static java.util.Collections.emptyList; import static java.util.Collections.singleton; import static java.util.Collections.singletonList; @@ -189,11 +191,11 @@ public class AthenzPrincipalFilterTest { } private static X509Certificate createSelfSignedCertificate(AthenzIdentity identity) { - KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA, 512); + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); X500Principal x500Name = new X500Principal("CN="+ identity.getFullName()); Instant now = Instant.now(); return X509CertificateBuilder - .fromKeypair(keyPair, x500Name, now, now.plus(Duration.ofDays(30)), SHA256_WITH_RSA, 1) + .fromKeypair(keyPair, x500Name, now, now.plus(Duration.ofDays(30)), SHA256_WITH_ECDSA, BigInteger.ONE) .build(); } diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/ConnectorFactory.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/ConnectorFactory.java index 6e3b6a65c51..8a829d33c1b 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/ConnectorFactory.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/ConnectorFactory.java @@ -24,6 +24,7 @@ import org.eclipse.jetty.server.SslConnectionFactory; import org.eclipse.jetty.util.ssl.SslContextFactory; import java.nio.channels.ServerSocketChannel; +import java.util.Arrays; import java.util.List; import java.util.function.BiConsumer; import java.util.function.Function; @@ -65,25 +66,11 @@ public class ConnectorFactory { connector.setName(connectorConfig.name()); connector.setAcceptQueueSize(connectorConfig.acceptQueueSize()); connector.setReuseAddress(connectorConfig.reuseAddress()); - double soLingerTimeSeconds = connectorConfig.soLingerTime(); - if (soLingerTimeSeconds == -1) { - setSoLingerTime(connector, -1); - } else { - setSoLingerTime(connector, (int)(soLingerTimeSeconds * 1000.0)); - } connector.setIdleTimeout((long)(connectorConfig.idleTimeout() * 1000.0)); connector.setStopTimeout((long)(connectorConfig.stopTimeout() * 1000.0)); return connector; } - @SuppressWarnings("deprecation") - private static void setSoLingerTime(ServerConnector connector, int milliseconds) { - // TODO: Don't use deprecated methods. Deprecate soLingerTime from connector config - // Jetty says: "don't use as socket close linger time has undefined behavior for non-blocking sockets" - // Jetty implementation is now a noop: https://github.com/eclipse/jetty.project/issues/2468, http://mail.openjdk.java.net/pipermail/nio-dev/2018-June/005195.html - connector.setSoLingerTime(milliseconds); - } - private HttpConnectionFactory newHttpConnectionFactory() { HttpConfiguration httpConfig = new HttpConfiguration(); httpConfig.setSendDateHeader(true); @@ -120,6 +107,13 @@ public class ConnectorFactory { factory.setSecureRandomAlgorithm(sslConfig.prng()); } + // NOTE: ^TLS_RSA_.*$ ciphers are disabled by default in Jetty 9.4.12+ (https://github.com/eclipse/jetty.project/issues/2807) + // JDisc will allow these ciphers by default to support older clients (e.g. Java 8u60 and curl 7.29.0) + String[] excludedCiphersWithoutTlsRsaExclusion = Arrays.stream(factory.getExcludeCipherSuites()) + .filter(cipher -> !cipher.equals("^TLS_RSA_.*$")) + .toArray(String[]::new); + factory.setExcludeCipherSuites(excludedCiphersWithoutTlsRsaExclusion); + setStringArrayParameter( factory, sslConfig.excludeProtocol(), ExcludeProtocol::name, SslContextFactory::setExcludeProtocols); setStringArrayParameter( diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java index 3a121e8b1ed..1e92fbef967 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java @@ -41,7 +41,7 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G } private static final String[] HTTP_RESPONSE_GROUPS = { Metrics.RESPONSES_1XX, Metrics.RESPONSES_2XX, Metrics.RESPONSES_3XX, - Metrics.RESPONSES_4XX, Metrics.RESPONSES_5XX }; + Metrics.RESPONSES_4XX, Metrics.RESPONSES_5XX, Metrics.RESPONSES_401, Metrics.RESPONSES_403}; private final AtomicLong inFlight = new AtomicLong(); private final LongAdder statistics[][]; @@ -112,6 +112,9 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G if (group >= 0) { HttpMethod method = getMethod(request); statistics[method.ordinal()][group].increment(); + if (group == 5 || group == 6) { // if 401/403, also increment 4xx + statistics[method.ordinal()][3].increment(); + } } long live = inFlight.decrementAndGet(); @@ -127,15 +130,19 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G } private int groupIndex(Request request) { - if (request.isHandled()) { - int index = (request.getResponse().getStatus() / 100) - 1; // 1xx = 0, 2xx = 1 etc. - if (index < 0 || index > statistics.length) { - return -1; - } else { - return index; - } + int index = request.getResponse().getStatus(); + if (index == 401) { + return 5; + } + if (index == 403) { + return 6; + } + + index = index / 100 - 1; // 1xx = 0, 2xx = 1 etc. + if (index < 0 || index >= statistics[0].length) { + return -1; } else { - return 3; // 4xx + return index; } } @@ -203,4 +210,10 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G } return shutdownCb; } + + @Override + public boolean isShutdown() { + FutureCallback futureCallback = shutdown.get(); + return futureCallback != null && futureCallback.isDone(); + } } diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java index 70d266fdfa5..8074af7f64f 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java @@ -108,6 +108,8 @@ public class JettyHttpServer extends AbstractServerProvider { String RESPONSES_3XX = "http.status.3xx"; String RESPONSES_4XX = "http.status.4xx"; String RESPONSES_5XX = "http.status.5xx"; + String RESPONSES_401 = "http.status.401"; + String RESPONSES_403 = "http.status.403"; String STARTED_MILLIS = "serverStartedMillis"; @Deprecated String MANHATTAN_STARTED_MILLIS = "proc.uptime"; diff --git a/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.connector.def b/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.connector.def index f0673e240c7..9ae4713c633 100644 --- a/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.connector.def +++ b/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.connector.def @@ -25,8 +25,8 @@ acceptQueueSize int default=0 # Whether the server socket reuses addresses. reuseAddress bool default=true -# TODO Vespa 7: Remove soLingerTime - Jetty no longer support it -# DEPRECATED The linger time in seconds. Use -1.0 to disable. +# TODO Vespa 7: Remove soLingerTime - Jetty no longer support it. +# DEPRECATED No longer in use soLingerTime double default=-1.0 # The maximum idle time for a connection, which roughly translates to the Socket.setSoTimeout(int). diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java index e3d70fb5bd6..3c23a2b0937 100644 --- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java @@ -65,6 +65,19 @@ public class HttpResponseStatisticsCollectorTest { } @Test + public void statistics_include_grouped_and_single_statuscodes() throws Exception { + testRequest(401, "GET"); + testRequest(404, "GET"); + testRequest(403, "GET"); + + Map<String, Map<String, Long>> stats = collector.takeStatisticsByMethod(); + assertThat(stats.get("GET").get(Metrics.RESPONSES_4XX), equalTo(3L)); + assertThat(stats.get("GET").get(Metrics.RESPONSES_401), equalTo(1L)); + assertThat(stats.get("GET").get(Metrics.RESPONSES_403), equalTo(1L)); + + } + + @Test public void retrieving_statistics_resets_the_counters() throws Exception { testRequest(200, "GET"); testRequest(200, "GET"); diff --git a/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java b/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java index d0f5de54b4f..bc3d1edda7c 100644 --- a/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java +++ b/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java @@ -32,7 +32,7 @@ public final class MbusServer extends AbstractResource implements ServerProvider private final ResourceReference sessionReference; @Inject - public MbusServer(final CurrentContainer container, final ServerSession session) { + public MbusServer(CurrentContainer container, ServerSession session) { this.container = container; this.session = session; uri = URI.create("mbus://localhost/" + session.name()); @@ -60,7 +60,7 @@ public final class MbusServer extends AbstractResource implements ServerProvider } @Override - public void handleMessage(final Message msg) { + public void handleMessage(Message msg) { if (!running.get()) { dispatchErrorReply(msg, ErrorCode.SESSION_BUSY, "Session temporarily closed."); return; @@ -73,7 +73,7 @@ public final class MbusServer extends AbstractResource implements ServerProvider try { request = new MbusRequest(container, uri, msg); content = request.connect(new ServerResponseHandler(msg)); - } catch (final RuntimeException e) { + } catch (RuntimeException e) { dispatchErrorReply(msg, ErrorCode.APP_FATAL_ERROR, e.toString()); } finally { if (request != null) { @@ -89,8 +89,8 @@ public final class MbusServer extends AbstractResource implements ServerProvider return session.connectionSpec(); } - private void dispatchErrorReply(final Message msg, final int errCode, final String errMsg) { - final Reply reply = new EmptyReply(); + private void dispatchErrorReply(Message msg, int errCode, String errMsg) { + Reply reply = new EmptyReply(); reply.swapState(msg); reply.addError(new Error(errCode, errMsg)); session.sendReply(reply); @@ -100,20 +100,20 @@ public final class MbusServer extends AbstractResource implements ServerProvider final Message msg; - ServerResponseHandler(final Message msg) { + ServerResponseHandler(Message msg) { this.msg = msg; } @Override - public ContentChannel handleResponse(final Response response) { - final Reply reply; + public ContentChannel handleResponse(Response response) { + Reply reply; if (response instanceof MbusResponse) { reply = ((MbusResponse)response).getReply(); } else { reply = new EmptyReply(); reply.swapState(msg); } - final Error err = StatusCodes.toMbusError(response.getStatus()); + Error err = StatusCodes.toMbusError(response.getStatus()); if (err != null) { if (err.isFatal()) { if (!reply.hasFatalErrors()) { diff --git a/jrt/pom.xml b/jrt/pom.xml index cf3da2ab7ce..84578f9e04d 100644 --- a/jrt/pom.xml +++ b/jrt/pom.xml @@ -19,6 +19,11 @@ <scope>test</scope> </dependency> <dependency> + <groupId>org.bouncycastle</groupId> + <artifactId>bcpkix-jdk15on</artifactId> + <scope>test</scope> + </dependency> + <dependency> <groupId>com.yahoo.vespa</groupId> <artifactId>annotations</artifactId> <version>${project.version}</version> diff --git a/jrt/src/com/yahoo/jrt/CryptoEngine.java b/jrt/src/com/yahoo/jrt/CryptoEngine.java index 9852d5a88a6..2ef936ec7ed 100644 --- a/jrt/src/com/yahoo/jrt/CryptoEngine.java +++ b/jrt/src/com/yahoo/jrt/CryptoEngine.java @@ -2,7 +2,10 @@ package com.yahoo.jrt; +import com.yahoo.security.tls.TransportSecurityOptions; + import java.nio.channels.SocketChannel; +import java.nio.file.Paths; /** @@ -13,5 +16,12 @@ import java.nio.channels.SocketChannel; **/ public interface CryptoEngine { public CryptoSocket createCryptoSocket(SocketChannel channel, boolean isServer); - static public CryptoEngine createDefault() { return new NullCryptoEngine(); } + static public CryptoEngine createDefault() { // TODO Move this logic to a dedicated factory class + String tlsConfigParameter = System.getenv("VESPA_TLS_CONFIG_FILE"); + if (tlsConfigParameter != null && !tlsConfigParameter.isEmpty()) { + return new TlsCryptoEngine(TransportSecurityOptions.fromJsonFile(Paths.get(tlsConfigParameter))); + } else { + return new NullCryptoEngine(); + } + } } diff --git a/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java b/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java new file mode 100644 index 00000000000..b3daf5c296d --- /dev/null +++ b/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java @@ -0,0 +1,48 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jrt; + +import com.yahoo.security.SslContextBuilder; +import com.yahoo.security.X509CertificateUtils; +import com.yahoo.security.tls.TransportSecurityOptions; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.channels.SocketChannel; +import java.nio.file.Files; +import java.security.cert.X509Certificate; +import java.util.List; + +/** + * A {@link CryptoSocket} that creates {@link TlsCryptoSocket} instances. + * + * @author bjorncs + */ +public class TlsCryptoEngine implements CryptoEngine { + + private final SSLContext sslContext; + + public TlsCryptoEngine(SSLContext sslContext) { + this.sslContext = sslContext; + } + + public TlsCryptoEngine(TransportSecurityOptions options) { + this(createSslContext(options)); + } + + @Override + public TlsCryptoSocket createCryptoSocket(SocketChannel channel, boolean isServer) { + SSLEngine sslEngine = sslContext.createSSLEngine(); + sslEngine.setNeedClientAuth(true); + sslEngine.setUseClientMode(!isServer); + return new TlsCryptoSocket(channel, sslEngine); + } + + private static SSLContext createSslContext(TransportSecurityOptions options) { + return new SslContextBuilder() + .withTrustStore(options.getCaCertificatesFile()) + .withKeyStore(options.getPrivateKeyFile(), options.getCertificatesFile()) + .build(); + } +} diff --git a/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java b/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java new file mode 100644 index 00000000000..3db54811f9e --- /dev/null +++ b/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java @@ -0,0 +1,253 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jrt; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSession; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.SocketChannel; +import java.util.logging.Logger; + +import static javax.net.ssl.SSLEngineResult.*; + +/** + * A {@link CryptoSocket} using TLS ({@link SSLEngine}) + * + * @author bjorncs + */ +public class TlsCryptoSocket implements CryptoSocket { + + private static final ByteBuffer NULL_BUFFER = ByteBuffer.allocate(0); + + private static final Logger log = Logger.getLogger(TlsCryptoSocket.class.getName()); + + private enum HandshakeState { NOT_STARTED, NEED_READ, NEED_WRITE, COMPLETED } + + private final SocketChannel channel; + private final SSLEngine sslEngine; + private final Buffer wrapBuffer; + private final Buffer unwrapBuffer; + private int sessionPacketBufferSize; + private int sessionApplicationBufferSize; + private ByteBuffer handshakeDummyBuffer; + private HandshakeState handshakeState; + + public TlsCryptoSocket(SocketChannel channel, SSLEngine sslEngine) { + this.channel = channel; + this.sslEngine = sslEngine; + SSLSession nullSession = sslEngine.getSession(); + this.wrapBuffer = new Buffer(nullSession.getPacketBufferSize() * 2); + this.unwrapBuffer = new Buffer(nullSession.getPacketBufferSize() * 2); + // Note: Dummy buffer as unwrap requires a full size application buffer even though no application data is unwrapped + this.handshakeDummyBuffer = ByteBuffer.allocate(nullSession.getApplicationBufferSize()); + this.handshakeState = HandshakeState.NOT_STARTED; + } + + @Override + public SocketChannel channel() { + return channel; + } + + @Override + public HandshakeResult handshake() throws IOException { + HandshakeState newHandshakeState = processHandshakeState(this.handshakeState); + log.fine(() -> String.format("Handshake state '%s -> %s'", this.handshakeState, newHandshakeState)); + this.handshakeState = newHandshakeState; + return toHandshakeResult(newHandshakeState); + } + + private HandshakeState processHandshakeState(HandshakeState state) throws IOException { + switch (state) { + case NOT_STARTED: + sslEngine.beginHandshake(); + break; + case NEED_WRITE: + channelWrite(); + break; + case NEED_READ: + channelRead(); + break; + case COMPLETED: + return HandshakeState.COMPLETED; + default: + throw unhandledStateException(state); + } + + while (true) { + switch (sslEngine.getHandshakeStatus()) { + case NOT_HANDSHAKING: + if (wrapBuffer.bytes() > 0) return HandshakeState.NEED_WRITE; + sslEngine.setEnableSessionCreation(false); // disable renegotiation + handshakeDummyBuffer = null; + SSLSession session = sslEngine.getSession(); + sessionApplicationBufferSize = session.getApplicationBufferSize(); + sessionPacketBufferSize = session.getPacketBufferSize(); + return HandshakeState.COMPLETED; + case NEED_TASK: + sslEngine.getDelegatedTask().run(); + break; + case NEED_UNWRAP: + if (wrapBuffer.bytes() > 0) return HandshakeState.NEED_WRITE; + if (!handshakeUnwrap()) return HandshakeState.NEED_READ; + break; + case NEED_WRAP: + if (!handshakeWrap()) return HandshakeState.NEED_WRITE; + break; + default: + throw new IllegalStateException("Unexpected handshake status: " + sslEngine.getHandshakeStatus()); + } + } + } + + private static HandshakeResult toHandshakeResult(HandshakeState state) { + switch (state) { + case NEED_READ: + return HandshakeResult.NEED_READ; + case NEED_WRITE: + return HandshakeResult.NEED_WRITE; + case COMPLETED: + return HandshakeResult.DONE; + default: + throw unhandledStateException(state); + } + } + + @Override + public int getMinimumReadBufferSize() { + return sessionApplicationBufferSize; + } + + @Override + public int read(ByteBuffer dst) throws IOException { + verifyHandshakeCompleted(); + int bytesUnwrapped = drain(dst); + if (bytesUnwrapped > 0) return bytesUnwrapped; + + int bytesRead = channelRead(); + if (bytesRead == 0) return 0; + return drain(dst); + } + + @Override + public int drain(ByteBuffer dst) throws IOException { + verifyHandshakeCompleted(); + int totalBytesUnwrapped = 0; + int bytesUnwrapped; + do { + bytesUnwrapped = applicationDataUnwrap(dst); + totalBytesUnwrapped += bytesUnwrapped; + } while (bytesUnwrapped > 0); + return totalBytesUnwrapped; + } + + @Override + public int write(ByteBuffer src) throws IOException { + if (flush() == FlushResult.NEED_WRITE) return 0; + int totalBytesWrapped = 0; + int bytesWrapped; + do { + bytesWrapped = applicationDataWrap(src); + totalBytesWrapped += bytesWrapped; + } while (bytesWrapped > 0 && wrapBuffer.bytes() < sessionPacketBufferSize); + return totalBytesWrapped; + } + + @Override + public FlushResult flush() throws IOException { + channelWrite(); + return wrapBuffer.bytes() > 0 ? FlushResult.NEED_WRITE : FlushResult.DONE; + } + + private boolean handshakeWrap() throws IOException { + SSLEngineResult result = sslEngineWrap(NULL_BUFFER); + switch (result.getStatus()) { + case OK: + return true; + case BUFFER_OVERFLOW: + return false; + default: + throw unexpectedStatusException(result.getStatus()); + } + } + + private int applicationDataWrap(ByteBuffer src) throws IOException { + SSLEngineResult result = sslEngineWrap(src); + if (result.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING) throw new SSLException("Renegotiation detected"); + switch (result.getStatus()) { + case OK: + return result.bytesConsumed(); + case BUFFER_OVERFLOW: + return 0; + default: + throw unexpectedStatusException(result.getStatus()); + } + } + + private SSLEngineResult sslEngineWrap(ByteBuffer src) throws IOException { + SSLEngineResult result = sslEngine.wrap(src, wrapBuffer.getWritable(sessionPacketBufferSize)); + if (result.getStatus() == Status.CLOSED) throw new ClosedChannelException(); + return result; + } + + private boolean handshakeUnwrap() throws IOException { + SSLEngineResult result = sslEngineUnwrap(handshakeDummyBuffer); + switch (result.getStatus()) { + case OK: + if (result.bytesProduced() > 0) throw new SSLException("Got application data in handshake unwrap"); + return true; + case BUFFER_UNDERFLOW: + return false; + default: + throw unexpectedStatusException(result.getStatus()); + } + } + + private int applicationDataUnwrap(ByteBuffer dst) throws IOException { + SSLEngineResult result = sslEngineUnwrap(dst); + if (result.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING) throw new SSLException("Renegotiation detected"); + switch (result.getStatus()) { + case OK: + return result.bytesProduced(); + case BUFFER_OVERFLOW: + case BUFFER_UNDERFLOW: + return 0; + default: + throw unexpectedStatusException(result.getStatus()); + } + } + + private SSLEngineResult sslEngineUnwrap(ByteBuffer dst) throws IOException { + SSLEngineResult result = sslEngine.unwrap(unwrapBuffer.getReadable(), dst); + if (result.getStatus() == Status.CLOSED) throw new ClosedChannelException(); + return result; + } + + // returns number of bytes read + private int channelRead() throws IOException { + int read = channel.read(unwrapBuffer.getWritable(sessionPacketBufferSize)); + if (read == -1) throw new ClosedChannelException(); + return read; + } + + // returns number of bytes written + private int channelWrite() throws IOException { + return channel.write(wrapBuffer.getReadable()); + } + + private static IllegalStateException unhandledStateException(HandshakeState state) { + return new IllegalStateException("Unhandled state: " + state); + } + + private static IllegalStateException unexpectedStatusException(Status status) { + return new IllegalStateException("Unexpected status: " + status); + } + + private void verifyHandshakeCompleted() throws SSLException { + if (handshakeState != HandshakeState.COMPLETED) + throw new SSLException("Handshake not completed: handshakeState=" + handshakeState); + } + +} diff --git a/jrt/tests/com/yahoo/jrt/CryptoUtils.java b/jrt/tests/com/yahoo/jrt/CryptoUtils.java new file mode 100644 index 00000000000..c3128e09bd3 --- /dev/null +++ b/jrt/tests/com/yahoo/jrt/CryptoUtils.java @@ -0,0 +1,43 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jrt; + +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.SslContextBuilder; +import com.yahoo.security.X509CertificateBuilder; + +import javax.net.ssl.SSLContext; +import javax.security.auth.x500.X500Principal; +import java.security.KeyPair; +import java.security.KeyStore; +import java.security.cert.X509Certificate; +import java.time.Instant; + +import static com.yahoo.security.KeyAlgorithm.RSA; +import static com.yahoo.security.KeyStoreType.PKCS12; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_RSA; +import static com.yahoo.security.X509CertificateBuilder.generateRandomSerialNumber; +import static java.time.Instant.EPOCH; +import static java.time.temporal.ChronoUnit.DAYS; + +/** + * @author bjorncs + */ +class CryptoUtils { + static SSLContext createTestSslContext() { + KeyPair keyPair = KeyUtils.generateKeypair(RSA); + + X509Certificate certificate = X509CertificateBuilder + .fromKeypair(keyPair, new X500Principal("CN=dummy"), EPOCH, Instant.now().plus(1, DAYS), SHA256_WITH_RSA, generateRandomSerialNumber()) + .build(); + + KeyStore trustStore = KeyStoreBuilder.withType(PKCS12) + .withCertificateEntry("self-signed", certificate) + .build(); + + return new SslContextBuilder() + .withTrustStore(trustStore) + .withKeyStore(keyPair.getPrivate(), certificate) + .build(); + } +} diff --git a/jrt/tests/com/yahoo/jrt/EchoTest.java b/jrt/tests/com/yahoo/jrt/EchoTest.java index 0523241354a..a91ac117f41 100644 --- a/jrt/tests/com/yahoo/jrt/EchoTest.java +++ b/jrt/tests/com/yahoo/jrt/EchoTest.java @@ -5,11 +5,11 @@ package com.yahoo.jrt; import org.junit.After; import org.junit.Before; import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameter; import org.junit.runners.Parameterized.Parameters; -import org.junit.runners.Parameterized; - +import static com.yahoo.jrt.CryptoUtils.createTestSslContext; import static org.junit.Assert.assertTrue; @RunWith(Parameterized.class) @@ -22,8 +22,8 @@ public class EchoTest { Values refValues; @Parameter public CryptoEngine crypto; - @Parameters public static Object[] engines() { - return new Object[] { CryptoEngine.createDefault(), new XorCryptoEngine() }; + @Parameters(name = "{0}") public static Object[] engines() { + return new Object[] { CryptoEngine.createDefault(), new XorCryptoEngine(), new TlsCryptoEngine(createTestSslContext()) }; } @Before diff --git a/jrt/tests/com/yahoo/jrt/SessionTest.java b/jrt/tests/com/yahoo/jrt/SessionTest.java index 2f1a64538de..63d14601b6e 100644 --- a/jrt/tests/com/yahoo/jrt/SessionTest.java +++ b/jrt/tests/com/yahoo/jrt/SessionTest.java @@ -5,10 +5,11 @@ package com.yahoo.jrt; import org.junit.After; import org.junit.Before; import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameter; import org.junit.runners.Parameterized.Parameters; -import org.junit.runners.Parameterized; +import static com.yahoo.jrt.CryptoUtils.createTestSslContext; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -17,8 +18,8 @@ import static org.junit.Assert.assertTrue; public class SessionTest implements SessionHandler { @Parameter public CryptoEngine crypto; - @Parameters public static Object[] engines() { - return new Object[] { CryptoEngine.createDefault(), new XorCryptoEngine() }; + @Parameters(name = "{0}") public static Object[] engines() { + return new Object[] { CryptoEngine.createDefault(), new XorCryptoEngine(), new TlsCryptoEngine(createTestSslContext()) }; } private static class Session { diff --git a/jrt_test/src/jrt-test/simpleserver/simpleserver.cpp b/jrt_test/src/jrt-test/simpleserver/simpleserver.cpp index a0781ee4720..89d8cd881a8 100644 --- a/jrt_test/src/jrt-test/simpleserver/simpleserver.cpp +++ b/jrt_test/src/jrt-test/simpleserver/simpleserver.cpp @@ -14,19 +14,19 @@ public: { FRT_ReflectionBuilder rb(s); //--------------------------------------------------------------------- - rb.DefineMethod("inc", "i", "i", true, + rb.DefineMethod("inc", "i", "i", FRT_METHOD(Server::rpc_inc), this); rb.MethodDesc("Increase an integer value"); rb.ParamDesc("value", "initial value"); rb.ReturnDesc("result", "value + 1"); //--------------------------------------------------------------------- - rb.DefineMethod("blob", "x", "x", true, + rb.DefineMethod("blob", "x", "x", FRT_METHOD(Server::rpc_blob), this); rb.MethodDesc("Send a copy of a blob back to the client"); rb.ParamDesc("blob", "the original blob"); rb.ReturnDesc("blob", "a copy of the original blob"); //--------------------------------------------------------------------- - rb.DefineMethod("test", "iib", "i", true, + rb.DefineMethod("test", "iib", "i", FRT_METHOD(Server::rpc_test), this); rb.MethodDesc("Magic test method"); rb.ParamDesc("value", "the value"); diff --git a/jrt_test/src/tests/mockup-invoke/mockup-server.cpp b/jrt_test/src/tests/mockup-invoke/mockup-server.cpp index 32c9bcc6c21..8456bee1e41 100644 --- a/jrt_test/src/tests/mockup-invoke/mockup-server.cpp +++ b/jrt_test/src/tests/mockup-invoke/mockup-server.cpp @@ -14,7 +14,7 @@ public: { FRT_ReflectionBuilder rb(s); //------------------------------------------------------------------- - rb.DefineMethod("concat", "ss", "s", true, + rb.DefineMethod("concat", "ss", "s", FRT_METHOD(MockupServer::RPC_concat), this); rb.MethodDesc("Concatenate two strings"); rb.ParamDesc("string1", "a string"); diff --git a/linguistics/src/main/java/com/yahoo/language/detect/Detection.java b/linguistics/src/main/java/com/yahoo/language/detect/Detection.java index c08bdc14cfb..4b816335154 100644 --- a/linguistics/src/main/java/com/yahoo/language/detect/Detection.java +++ b/linguistics/src/main/java/com/yahoo/language/detect/Detection.java @@ -7,7 +7,7 @@ import java.nio.charset.Charset; import java.nio.charset.UnsupportedCharsetException; /** - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> + * @author Einar M R Rosenvinge */ public class Detection { diff --git a/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java b/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java index 12de309a2d3..7451a7f2c9c 100644 --- a/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java +++ b/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java @@ -4,8 +4,10 @@ import com.yahoo.language.process.Tokenizer; import com.yahoo.language.simple.SimpleLinguistics; public class OpenNlpLinguistics extends SimpleLinguistics { + @Override public Tokenizer getTokenizer() { return new OpenNlpTokenizer(getNormalizer(), getTransformer()); } + } diff --git a/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java b/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java index 2b31f95675b..0503ac61df1 100644 --- a/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java +++ b/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java @@ -37,36 +37,49 @@ import java.util.Locale; * @author bjorncs */ public class SimpleDetector implements Detector { - static private TextObjectFactory textObjectFactory; - static private LanguageDetector languageDetector; - - static { - // origin: https://github.com/optimaize/language-detector - //load all languages: - List<LanguageProfile> languageProfiles; - try { - languageProfiles = new LanguageProfileReader().readAllBuiltIn(); - } catch (IOException e) { - throw new RuntimeException(e); - } - //build language detector: - languageDetector = LanguageDetectorBuilder.create(NgramExtractors.standard()) - .withProfiles(languageProfiles) - .build(); + static private Object initGuard = new Object(); + static private TextObjectFactory textObjectFactory = null; + static private LanguageDetector languageDetector = null; + + static private void initOptimaize (boolean useOptimaize) { + if (!useOptimaize) return; + synchronized (initGuard) { + if ((textObjectFactory != null) && (languageDetector != null)) return; + + // origin: https://github.com/optimaize/language-detector + //load all languages: + List<LanguageProfile> languageProfiles; + try { + languageProfiles = new LanguageProfileReader().readAllBuiltIn(); + } catch (IOException e) { + throw new RuntimeException(e); + } + + //build language detector: + languageDetector = LanguageDetectorBuilder.create(NgramExtractors.standard()) + .withProfiles(languageProfiles) + .build(); - //create a text object factory - textObjectFactory = CommonTextObjectFactories.forDetectingOnLargeText(); + //create a text object factory + textObjectFactory = CommonTextObjectFactories.forDetectingOnLargeText(); + } } private final boolean enableOptimaize; + private SimpleDetector(boolean enableOptimaize) { + initOptimaize(enableOptimaize); + this.enableOptimaize = enableOptimaize; + + } + public SimpleDetector() { - this.enableOptimaize = true; + this(true); } public SimpleDetector(SimpleLinguisticsConfig.Detector detector) { - this.enableOptimaize = detector.enableOptimaize(); + this(detector.enableOptimaize()); } @Override diff --git a/logd/CMakeLists.txt b/logd/CMakeLists.txt index 3eeeb7adb66..6a8296564a3 100644 --- a/logd/CMakeLists.txt +++ b/logd/CMakeLists.txt @@ -18,3 +18,5 @@ vespa_define_module( src/tests/info src/tests/rotate ) + +vespa_install_script(src/apps/retention/retention-enforcer.sh vespa-retention-enforcer sbin) diff --git a/logd/src/apps/retention/retention-enforcer.sh b/logd/src/apps/retention/retention-enforcer.sh new file mode 100755 index 00000000000..7ab1b27d71a --- /dev/null +++ b/logd/src/apps/retention/retention-enforcer.sh @@ -0,0 +1,137 @@ +#!/bin/sh + +# daemon that collects old log files. +# global settings: + +DBGF=logs/vespa/debug.retention-enforcer +DBDIR=var/db/vespa/logfiledb +PIDF=$DBDIR/retention-enforcer.pid +RETAIN_DAYS=31 + +# this depends on components adding their log files +# to a "database" in DBDIR named "logfiles.TTTTT" where +# TTTTT is a timestamp in format (seconds/100000). +# The "database" holds lines with format "timestamp /path/to/logfile" +# where "timestamp" is just seconds since epoch. + +prereq_dir() { + if [ -d $1 ] && [ -w $1 ]; then + : + else + echo "$0: missing directory '$1' in '`pwd`'" >&2 + exit 1 + fi +} + +check_prereqs() { + prereq_dir var/db/vespa + prereq_dir logs/vespa +} + +ensure_dir () { + if [ -d $1 ] && [ -w $1 ]; then + return 0 + fi + echo "Creating directory '$1' in '`pwd`'" + mkdir -p $1 || exit 1 +} + +prepare_stuff() { + check_prereqs + exec > $DBGF.$$.log 2>&1 + ensure_dir $DBDIR +} + +bad_timestamp() { + now=$(date +%s) + if [ "$1" ] && [ "$1" -ge 1514764800 ] && [ "$1" -le $now ]; then + # sane timestamp: + return 1 + fi + # bad timestamp: + return 0 +} + +mark_pid() { + echo $$ > $PIDF.$$.tmp + mv $PIDF.$$.tmp $PIDF || exit 1 +} + +check_pidfile() { + read pid < $PIDF + [ "$pid" = $$ ] && return 0 + if [ "$pid" ] && [ $pid -gt $$ ]; then + sleep 30 + read pid_again < $PIDF + if [ "$pid_again" != "$pid" ]; then return 1; fi + ps -p $pid >/dev/null 2>&1 || return 1 + proc=$(ps -p $pid 2>&1) + case $proc in *retention*) ;; *) return 1;; esac + echo "$0 [$$]: Yielding my place to pid '$pid'" + exit 1 + fi +} + +get_mod_time() { + perl -e 'print (((stat("'"$1"'"))[9]) . "\n")' +} + +maybe_collect() { + timestamp=$1 + logfilename=$2 + + if bad_timestamp "$1"; then + echo "WARNING: bad timestamp '$timestamp' for logfilename '$logfilename'" + return + fi + + add=$((86400 * $RETAIN_DAYS)) + lim1=$(($timestamp + $add)) + mod_time=$(get_mod_time "$logfilename") + lim2=$(($mod_time + $add)) + + if [ $lim1 -lt $now ] && [ $lim2 -lt $now ]; then + echo "Collect logfile '$logfilename' timestamped $timestamp modified $mod_time" + rm -f "$logfilename" + fi +} + +process_file() { + dbfile="$1" + now=$(date +%s) + found=0 + while read timestamp logfilename; do + for fn in $logfilename $logfilename.*z*; do + if [ -f "$fn" ]; then + found=1 + maybe_collect "$timestamp" "$fn" + fi + done + done < $dbfile + if [ $found = 0 ]; then + ts=${dbfile##*.}99999 + maybe_collect "$ts" "$dbfile" + fi +} + +process_all() { + for dbf in $DBDIR/logfiles.* ; do + [ -f "$dbf" ] || continue + process_file "$dbf" + done +} + +mainloop() { + while true; do + mark_pid + process_all + sleep 3600 + check_pidfile + done +} + +# MAIN: + +prepare_stuff +mainloop +exit 0 diff --git a/messagebus/src/main/java/com/yahoo/messagebus/Message.java b/messagebus/src/main/java/com/yahoo/messagebus/Message.java index 22496487f61..43f5c8d2dfd 100644 --- a/messagebus/src/main/java/com/yahoo/messagebus/Message.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/Message.java @@ -5,9 +5,9 @@ import com.yahoo.concurrent.SystemTimer; import com.yahoo.messagebus.routing.Route; /** - * <p>A message is a child of Routable, it is not a reply, and it has a sequencing identifier. Furthermore, a message + * A message is a child of Routable, it is not a reply, and it has a sequencing identifier. Furthermore, a message * contains a retry counter that holds what retry the message is currently on. See the method comment {@link #getRetry} - * for more information.</p> + * for more information. * * @author Simon Thoresen Hult */ diff --git a/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java b/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java index 63514eca6dd..e21aeef1ee2 100755 --- a/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java @@ -20,15 +20,15 @@ public class Hop { private String cache = null; /** - * <p>Constructs an empty hop. You will need to add directives to the - * selector to make this usable.</p> + * Constructs an empty hop. You will need to add directives to the + * selector to make this usable. */ public Hop() { // empty } /** - * <p>Implements the copy constructor.</p> + * Implements the copy constructor. * * @param hop The hop to copy. */ @@ -38,8 +38,8 @@ public class Hop { } /** - * <p>Constructs a fully populated hop. This is package private and used by - * the {@link HopBlueprint#create()} method.</p> + * Constructs a fully populated hop. This is package private and used by + * the {@link HopBlueprint#create()} method. * * @param selector The selector to copy. * @param ignoreResult Whether or not to ignore the result of this hop. @@ -50,8 +50,8 @@ public class Hop { } /** - * <p>Parses the given string as a single hop. The {@link #toString()} - * method is compatible with this parser.</p> + * Parses the given string as a single hop. The {@link #toString()} + * method is compatible with this parser. * * @param str The string to parse. * @return A hop that corresponds to the string. @@ -65,8 +65,7 @@ public class Hop { } /** - * <p>Returns whether or not there are any directives contained in this - * hop.</p> + * Returns whether or not there are any directives contained in this hop. * * @return True if there is at least one directive. */ @@ -75,7 +74,7 @@ public class Hop { } /** - * <p>Returns the number of directives contained in this hop.</p> + * Returns the number of directives contained in this hop. * * @return The number of directives. */ @@ -84,7 +83,7 @@ public class Hop { } /** - * <p>Returns the directive at the given index.</p> + * Returns the directive at the given index. * * @param i The index of the directive to return. * @return The item. @@ -94,7 +93,7 @@ public class Hop { } /** - * <p>Adds a new directive to this hop.</p> + * Adds a new directive to this hop. * * @param directive The directive to add. * @return This, to allow chaining. @@ -106,7 +105,7 @@ public class Hop { } /** - * <p>Sets the directive at a given index.</p> + * Sets the directive at a given index. * * @param i The index at which to set the directive. * @param directive The directive to set. @@ -283,9 +282,10 @@ public class Hop { @Override public int hashCode() { - int result = selector != null ? selector.hashCode() : 0; + int result = selector.hashCode(); result = 31 * result + (ignoreResult ? 1 : 0); result = 31 * result + (cache != null ? cache.hashCode() : 0); return result; } + } diff --git a/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java b/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java index 809b2da69c4..838b11e7a02 100755 --- a/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java @@ -14,13 +14,14 @@ public interface HopDirective { * @param dir The directive to compare this to. * @return True if this matches the argument. */ - public boolean matches(HopDirective dir); + boolean matches(HopDirective dir); /** * Returns a string representation of this that can be debugged but not parsed. * * @return The debug string. */ - public String toDebugString(); + String toDebugString(); + } diff --git a/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java b/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java index a07c6e16100..9190b680ebf 100755 --- a/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java @@ -20,7 +20,7 @@ import java.util.List; */ public class Route { - private final List<Hop> hops = new ArrayList<Hop>(); + private final List<Hop> hops = new ArrayList<>(); private String cache = null; /** diff --git a/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp b/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp index 913338785f2..b72416f51d2 100644 --- a/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp +++ b/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp @@ -187,7 +187,7 @@ RPCNetwork::attach(INetworkOwner &owner) _sendAdapters[vespalib::Version(6, 149)] = _sendV2.get(); FRT_ReflectionBuilder builder(_orb.get()); - builder.DefineMethod("mbus.getVersion", "", "s", true, FRT_METHOD(RPCNetwork::invoke), this); + builder.DefineMethod("mbus.getVersion", "", "s", FRT_METHOD(RPCNetwork::invoke), this); builder.MethodDesc("Retrieves the message bus version."); builder.ReturnDesc("version", "The message bus version."); } diff --git a/messagebus/src/vespa/messagebus/network/rpcsendv1.cpp b/messagebus/src/vespa/messagebus/network/rpcsendv1.cpp index 6b89a278b88..376267b555c 100644 --- a/messagebus/src/vespa/messagebus/network/rpcsendv1.cpp +++ b/messagebus/src/vespa/messagebus/network/rpcsendv1.cpp @@ -35,7 +35,7 @@ RPCSendV1::getReturnSpec() const { void RPCSendV1::build(FRT_ReflectionBuilder & builder) { - builder.DefineMethod(METHOD_NAME, METHOD_PARAMS, METHOD_RETURN, true, FRT_METHOD(RPCSendV1::invoke), this); + builder.DefineMethod(METHOD_NAME, METHOD_PARAMS, METHOD_RETURN, FRT_METHOD(RPCSendV1::invoke), this); builder.MethodDesc("Send a message bus request and get a reply back."); builder.ParamDesc("version", "The version of the message."); builder.ParamDesc("route", "Names of additional hops to visit."); diff --git a/messagebus/src/vespa/messagebus/network/rpcsendv2.cpp b/messagebus/src/vespa/messagebus/network/rpcsendv2.cpp index 4c04549aee1..91a41a6a800 100644 --- a/messagebus/src/vespa/messagebus/network/rpcsendv2.cpp +++ b/messagebus/src/vespa/messagebus/network/rpcsendv2.cpp @@ -59,7 +59,7 @@ bool RPCSendV2::isCompatible(stringref method, stringref request, stringref resp void RPCSendV2::build(FRT_ReflectionBuilder & builder) { - builder.DefineMethod(METHOD_NAME, METHOD_PARAMS, METHOD_RETURN, true, FRT_METHOD(RPCSendV2::invoke), this); + builder.DefineMethod(METHOD_NAME, METHOD_PARAMS, METHOD_RETURN, FRT_METHOD(RPCSendV2::invoke), this); builder.MethodDesc("Send a message bus slime request and get a reply back."); builder.ParamDesc("header_encoding", "0=raw, 6=lz4"); builder.ParamDesc("header_decoded_size", "Uncompressed header blob size"); diff --git a/model-evaluation/OWNERS b/model-evaluation/OWNERS new file mode 100644 index 00000000000..2bd865cff34 --- /dev/null +++ b/model-evaluation/OWNERS @@ -0,0 +1,2 @@ +bratseth +lesters diff --git a/model-evaluation/README b/model-evaluation/README new file mode 100644 index 00000000000..0bf143a2804 --- /dev/null +++ b/model-evaluation/README @@ -0,0 +1,6 @@ +Provides +- an injectable component (ai.vespa.models.evaluation.ModelsEvaluator) which allows direct, stateless evaluation of + any machine learned models added to the models/ directory in any container. +- a handler (turned on with the <models-evaluation> tag in <container>) which provides the models-evaluation REST + API which provides stateless (single data point) model evaluation over HTTP(S). + diff --git a/model-evaluation/pom.xml b/model-evaluation/pom.xml index edb22c1b529..0421c680edf 100644 --- a/model-evaluation/pom.xml +++ b/model-evaluation/pom.xml @@ -40,6 +40,11 @@ </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> + <artifactId>searchcore</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> <artifactId>config</artifactId> <version>${project.version}</version> <scope>provided</scope> @@ -74,6 +79,9 @@ <groupId>com.yahoo.vespa</groupId> <artifactId>bundle-plugin</artifactId> <extensions>true</extensions> + <configuration> + <importPackage>com.yahoo.config</importPackage> <!-- To make DI see RankingConstantsConfig as a ConfigInstance --> + </configuration> </plugin> </plugins> </build> diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java new file mode 100644 index 00000000000..e664693ab38 --- /dev/null +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java @@ -0,0 +1,27 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import com.yahoo.tensor.Tensor; + +/** + * A named constant loaded from a file. + * + * This is immutable. + * + * @author bratseth + */ +class Constant { + + private final String name; + private final Tensor value; + + Constant(String name, Tensor value) { + this.name = name; + this.value = value; + } + + public String name() { return name; } + + public Tensor value() { return value; } + +} diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 520986ffb77..e08b9f77d15 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -56,4 +56,6 @@ public class FunctionEvaluator { return function.getBody().evaluate(context).asTensor(); } + LazyArrayContext context() { return context; } + } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 2dcfd204077..beaa36b898f 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -16,6 +17,7 @@ import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Set; @@ -37,8 +39,11 @@ final class LazyArrayContext extends Context implements ContextIndex { * * @param expression the expression to create a context for */ - LazyArrayContext(RankingExpression expression, Map<FunctionReference, ExpressionFunction> functions, Model model) { - this.indexedBindings = new IndexedBindings(expression, functions, this, model); + LazyArrayContext(RankingExpression expression, + Map<FunctionReference, ExpressionFunction> functions, + List<Constant> constants, + Model model) { + this.indexedBindings = new IndexedBindings(expression, functions, constants, this, model); } /** @@ -139,8 +144,10 @@ final class LazyArrayContext extends Context implements ContextIndex { */ IndexedBindings(RankingExpression expression, Map<FunctionReference, ExpressionFunction> functions, + List<Constant> constants, LazyArrayContext owner, Model model) { + // 1. Determine and prepare bind targets Set<String> bindTargets = new LinkedHashSet<>(); extractBindTargets(expression.getRoot(), functions, bindTargets); @@ -150,9 +157,18 @@ final class LazyArrayContext extends Context implements ContextIndex { int i = 0; ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>(); for (String variable : bindTargets) - nameToIndexBuilder.put(variable,i++); + nameToIndexBuilder.put(variable, i++); nameToIndex = nameToIndexBuilder.build(); + + // 2. Bind the bind targets + for (Constant constant : constants) { + String constantReference = "constant(" + constant.name() + ")"; + Integer index = nameToIndex.get(constantReference); + if (index != null) + values[index] = new TensorValue(constant.value()); + } + for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { Integer index = nameToIndex.get(function.getKey().serialForm()); if (index != null) // Referenced in this, so bind it @@ -170,7 +186,7 @@ final class LazyArrayContext extends Context implements ContextIndex { extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets); } else if (isConstant(node)) { - // Ignore + bindTargets.add(node.toString()); } else if (node instanceof ReferenceNode) { bindTargets.add(node.toString()); @@ -193,7 +209,7 @@ final class LazyArrayContext extends Context implements ContextIndex { if ( ! (node instanceof ReferenceNode)) return false; ReferenceNode reference = (ReferenceNode)node; - return reference.getName().equals("value") && reference.getArguments().size() == 1; + return reference.getName().equals("constant") && reference.getArguments().size() == 1; } Value get(int index) { return values[index]; } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 95eb923786d..3fb43d73187 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -36,11 +36,15 @@ public class Model { private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer(); + /** Programmatically create a model containing functions without constant of function references only */ public Model(String name, Collection<ExpressionFunction> functions) { - this(name, functions, Collections.emptyMap()); + this(name, functions, Collections.emptyMap(), Collections.emptyList()); } - Model(String name, Collection<ExpressionFunction> functions, Map<FunctionReference, ExpressionFunction> referencedFunctions) { + Model(String name, + Collection<ExpressionFunction> functions, + Map<FunctionReference, ExpressionFunction> referencedFunctions, + List<Constant> constants) { // TODO: Optimize functions this.name = name; this.functions = ImmutableList.copyOf(functions); @@ -48,7 +52,8 @@ public class Model { ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); for (ExpressionFunction function : functions) { try { - contextBuilder.put(function.getName(), new LazyArrayContext(function.getBody(), referencedFunctions, this)); + contextBuilder.put(function.getName(), + new LazyArrayContext(function.getBody(), referencedFunctions, constants, this)); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index dacf20b7ef2..a0b859bf930 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -3,8 +3,11 @@ package ai.vespa.models.evaluation; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableMap; +import com.google.inject.Inject; import com.yahoo.component.AbstractComponent; +import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import java.util.Map; import java.util.stream.Collectors; @@ -21,8 +24,15 @@ public class ModelsEvaluator extends AbstractComponent { private final ImmutableMap<String, Model> models; - public ModelsEvaluator(RankProfilesConfig config) { - models = ImmutableMap.copyOf(new RankProfilesConfigImporter().importFrom(config)); + @Inject + public ModelsEvaluator(RankProfilesConfig config, + RankingConstantsConfig constantsConfig, + FileAcquirer fileAcquirer) { + this(new RankProfilesConfigImporter(fileAcquirer).importFrom(config, constantsConfig)); + } + + public ModelsEvaluator(Map<String, Model> models) { + this.models = ImmutableMap.copyOf(models); } /** Returns the models of this as an immutable map */ diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index bfd6342218a..87ac53488db 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -1,33 +1,52 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import com.yahoo.config.FileReference; +import com.yahoo.filedistribution.fileacquirer.FileAcquirer; +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.TimeUnit; /** - * Converts RankProfilesConfig instances to RankingExpressions for evaluation + * Converts RankProfilesConfig instances to RankingExpressions for evaluation. + * This class can be used by a single thread only. * * @author bratseth */ -class RankProfilesConfigImporter { +public class RankProfilesConfigImporter { + + private final FileAcquirer fileAcquirer; + + public RankProfilesConfigImporter(FileAcquirer fileAcquirer) { + this.fileAcquirer = fileAcquirer; + } /** * Returns a map of the models contained in this config, indexed on name. * The map is modifiable and owned by the caller. */ - Map<String, Model> importFrom(RankProfilesConfig config) { + public Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) { try { Map<String, Model> models = new HashMap<>(); for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { - Model model = importProfile(profile); + Model model = importProfile(profile, constantsConfig); models.put(model.name(), model); } return models; @@ -37,11 +56,15 @@ class RankProfilesConfigImporter { } } - private Model importProfile(RankProfilesConfig.Rankprofile profile) throws ParseException { + private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) + throws ParseException { List<ExpressionFunction> functions = new ArrayList<>(); Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>(); ExpressionFunction firstPhase = null; ExpressionFunction secondPhase = null; + + List<Constant> constants = readConstants(constantsConfig); + for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) { Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name()); if ( reference.isPresent()) { @@ -52,7 +75,8 @@ class RankProfilesConfigImporter { functions.add(new ExpressionFunction(reference.get().functionName(), arguments, expression)); // // Make all functions, bound or not available under the name they are referenced by in expressions - referencedFunctions.put(reference.get(), new ExpressionFunction(reference.get().serialForm(), arguments, expression)); + referencedFunctions.put(reference.get(), + new ExpressionFunction(reference.get().serialForm(), arguments, expression)); } else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to macros firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(), @@ -69,7 +93,7 @@ class RankProfilesConfigImporter { functions.add(secondPhase); try { - return new Model(profile.name(), functions, referencedFunctions); + return new Model(profile.name(), functions, referencedFunctions, constants); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e); @@ -83,4 +107,35 @@ class RankProfilesConfigImporter { return null; } + private List<Constant> readConstants(RankingConstantsConfig constantsConfig) { + List<Constant> constants = new ArrayList<>(); + + for (RankingConstantsConfig.Constant constantConfig : constantsConfig.constant()) { + constants.add(new Constant(constantConfig.name(), + readTensorFromFile(constantConfig.name(), + TensorType.fromSpec(constantConfig.type()), + constantConfig.fileref()))); + } + return constants; + } + + protected Tensor readTensorFromFile(String name, TensorType type, FileReference fileReference) { + try { + File file = fileAcquirer.waitFor(fileReference, 7, TimeUnit.DAYS); + if (file.getName().endsWith(".tbf")) + return TypedBinaryFormat.decode(Optional.of(type), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(file))); + else + throw new IllegalArgumentException("Constant files on other formats than .tbf are not supported, got " + + file + " for constant " + name); + // TODO: Support json and json.lz4 + } + catch (InterruptedException e) { + throw new IllegalStateException("Gave up waiting for constant " + name); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java new file mode 100644 index 00000000000..6e55c0c9a53 --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -0,0 +1,74 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import org.junit.Test; + +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; + +/** + * Tests instantiating models from rank-profiles configs. + * + * @author bratseth + */ +public class MlModelsImportingTest { + + private static final double delta = 0.00000000001; + + @Test + public void testImportingModels() { + ModelTester tester = new ModelTester("src/test/resources/config/models/"); + + assertEquals(4, tester.models().size()); + + // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that + { + Model xgboost = tester.models().get("xgboost_2_2"); + tester.assertFunction("xgboost_2_2", + "(optimized sum of condition trees of size 192 bytes)", + xgboost); + FunctionEvaluator evaluator = xgboost.evaluatorOf(); + assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-8.17695, evaluator.evaluate().sum().asDouble(), delta); + } + + { + + Model onnxMnistSoftmax = tester.models().get("mnist_softmax"); + tester.assertFunction("default.add", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", + onnxMnistSoftmax); + assertEquals("tensor(d1[10],d2[784])", + onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); + FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available + assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta); + } + + { + Model tfMnistSoftmax = tester.models().get("mnist_softmax_saved"); + tester.assertFunction("serving_default.y", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", + tfMnistSoftmax); + FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available + assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta); + } + + { + Model tfMnist = tester.models().get("mnist_saved"); + tester.assertFunction("serving_default.y", + "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", + tfMnist); + // Macro: + tester.assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add", + "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", + tfMnist); + FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument + assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(-0.714629131972222, evaluator.evaluate().sum().asDouble(), delta); + } + } + +} diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java new file mode 100644 index 00000000000..0aceaccc3e0 --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java @@ -0,0 +1,94 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import com.yahoo.config.FileReference; +import com.yahoo.config.subscription.ConfigGetter; +import com.yahoo.config.subscription.FileSource; +import com.yahoo.filedistribution.fileacquirer.FileAcquirer; +import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; + +import java.io.IOException; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Logger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * Helper for testing model import and evaluation + * + * @author bratseth + */ +public class ModelTester { + + private final Map<String, Model> models; + + public ModelTester(String modelConfigDirectory) { + models = createModels(modelConfigDirectory); + } + + public Map<String, Model> models() { return models; } + + private static Map<String, Model> createModels(String path) { + Path configDir = Path.fromString(path); + RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), + RankProfilesConfig.class).getConfig(""); + RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), + RankingConstantsConfig.class).getConfig(""); + return new RankProfilesConfigImporterWithMockedConstants(Path.fromString(path).append("constants"), MockFileAcquirer.returnFile(null)) + .importFrom(config, constantsConfig); + } + + public void assertFunction(String name, String expression, Model model) { + assertNotNull("Model is present in config", model); + ExpressionFunction function = model.function(name); + assertNotNull("Function '" + name + "' is in " + model, function); + assertEquals(name, function.getName()); + assertEquals(expression, function.getBody().getRoot().toString()); + } + + public void assertBoundFunction(String name, String expression, Model model) { + ExpressionFunction function = model.referencedFunctions().get(FunctionReference.fromSerial(name).get()); + assertNotNull("Function '" + name + "' is present", function); + assertEquals(name, function.getName()); + assertEquals(expression, function.getBody().getRoot().toString()); + } + + /** Allows us to provide canned tensor constants during import since file distribution does not work in tests */ + private static class RankProfilesConfigImporterWithMockedConstants extends RankProfilesConfigImporter { + + private static final Logger log = Logger.getLogger(RankProfilesConfigImporterWithMockedConstants.class.getName()); + + private final Path constantsPath; + + public RankProfilesConfigImporterWithMockedConstants(Path constantsPath, FileAcquirer fileAcquirer) { + super(fileAcquirer); + this.constantsPath = constantsPath; + } + + @Override + protected Tensor readTensorFromFile(String name, TensorType type, FileReference fileReference) { + try { + return TypedBinaryFormat.decode(Optional.of(type), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantsPath.append(name).toFile()))); + } + catch (IOException e) { + log.warning("Missing a mocked tensor constant for '" + name + "': " + e.getMessage() + + ". Returning an empty tensor"); + return Tensor.from(type, "{}"); + } + } + + } + +} diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index 60cf0d25ded..bd1ff6b8ed7 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -3,12 +3,13 @@ package ai.vespa.models.evaluation; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.config.subscription.FileSource; +import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; +import com.yahoo.path.Path; import com.yahoo.tensor.Tensor; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import org.junit.Test; -import java.io.File; - import static org.junit.Assert.assertEquals; /** @@ -18,15 +19,9 @@ public class ModelsEvaluatorTest { private static final double delta = 0.00000000001; - private ModelsEvaluator createModels() { - String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg"; - RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); - return new ModelsEvaluator(config); - } - @Test public void testTensorEvaluation() { - ModelsEvaluator models = createModels(); + ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); FunctionEvaluator function = models.evaluatorOf("macros", "fourtimessum"); function.bind("var1", Tensor.from("{{x:0}:3,{x:1}:5}")); function.bind("var2", Tensor.from("{{x:0}:7,{x:1}:11}")); @@ -35,7 +30,7 @@ public class ModelsEvaluatorTest { @Test public void testEvaluationDependingOnMacroTakingArguments() { - ModelsEvaluator models = createModels(); + ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); FunctionEvaluator function = models.evaluatorOf("macros", "secondphase"); function.bind("match", 3); function.bind("rankBoost", 5); @@ -46,6 +41,14 @@ public class ModelsEvaluatorTest { // TODO: Test that binding nonexisting variable doesn't work // TODO: Test that rebinding doesn't work // TODO: Test with nested macros - // TODO: Test TF/ONNX model + + private ModelsEvaluator createModels(String path) { + Path configDir = Path.fromString(path); + RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), + RankProfilesConfig.class).getConfig(""); + RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), + RankingConstantsConfig.class).getConfig(""); + return new ModelsEvaluator(config, constantsConfig, MockFileAcquirer.returnFile(null)); + } } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java new file mode 100644 index 00000000000..20abd9c0fb0 --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java @@ -0,0 +1,34 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class RankProfileImportingTest { + + @Test + public void testImportingRankExpressions() { + ModelTester tester = new ModelTester("src/test/resources/config/rankexpression/"); + + assertEquals(18, tester.models().size()); + + Model macros = tester.models().get("macros"); + assertEquals("macros", macros.name()); + assertEquals(4, macros.functions().size()); + tester.assertFunction("fourtimessum", "4 * (var1 + var2)", macros); + tester.assertFunction("firstphase", "match + fieldMatch(title) + rankingExpression(myfeature)", macros); + tester.assertFunction("secondphase", "rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", macros); + tester.assertFunction("myfeature", + "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + " + + "30 * pow(0 - fieldMatch(description).earliness,2)", + macros); + assertEquals(4, macros.referencedFunctions().size()); + tester.assertBoundFunction("rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", + "4 * (match + rankBoost)", macros); + } + +} diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java deleted file mode 100644 index d45372fc7da..00000000000 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.models.evaluation; - -import com.yahoo.config.subscription.ConfigGetter; -import com.yahoo.config.subscription.FileSource; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.vespa.config.search.RankProfilesConfig; -import org.junit.Test; - -import java.io.File; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - -/** - * Tests instantiating models from rank-profiles configs. - * - * @author bratseth - */ -public class RankProfilesImporterTest { - - @Test - public void testImporting() { - String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg"; - RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); - Map<String, Model> models = new RankProfilesConfigImporter().importFrom(config); - assertEquals(18, models.size()); - - Model macros = models.get("macros"); - assertNotNull(macros); - assertEquals("macros", macros.name()); - assertEquals(4, macros.functions().size()); - assertFunction("fourtimessum", "4 * (var1 + var2)", macros); - assertFunction("firstphase", "match + fieldMatch(title) + rankingExpression(myfeature)", macros); - assertFunction("secondphase", "rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", macros); - assertFunction("myfeature", - "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + " + - "30 * pow(0 - fieldMatch(description).earliness,2)", - macros); - assertEquals(4, macros.referencedFunctions().size()); - assertBoundFunction("rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", - "4 * (match + rankBoost)", macros); - } - - private void assertFunction(String name, String expression, Model model) { - ExpressionFunction function = model.function(name); - assertNotNull(function); - assertEquals(name, function.getName()); - assertEquals(expression, function.getBody().getRoot().toString()); - } - - private void assertBoundFunction(String name, String expression, Model model) { - ExpressionFunction function = model.referencedFunctions().get(FunctionReference.fromSerial(name).get()); - assertNotNull("Function '" + name + "' is present", function); - assertEquals(name, function.getName()); - assertEquals(expression, function.getBody().getRoot().toString()); - } - -} diff --git a/model-evaluation/src/test/resources/config/models/constants/README b/model-evaluation/src/test/resources/config/models/constants/README new file mode 100644 index 00000000000..4a274aa95c8 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/README @@ -0,0 +1 @@ +These constants was created by writing TypedBinaryFormat.encode(tensor) on each large constant produced by these models. diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden1_bias_read b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden1_bias_read Binary files differnew file mode 100644 index 00000000000..bac75f7b1e7 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden1_bias_read diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden1_weights_read b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden1_weights_read Binary files differnew file mode 100644 index 00000000000..bd3f05be826 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden1_weights_read diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden2_bias_read b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden2_bias_read Binary files differnew file mode 100644 index 00000000000..fca7c76df3f --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden2_bias_read diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden2_weights_read b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden2_weights_read Binary files differnew file mode 100644 index 00000000000..396dea8f4bc --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_hidden2_weights_read diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_outputs_bias_read b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_outputs_bias_read Binary files differnew file mode 100644 index 00000000000..42f85478c10 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_outputs_bias_read diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_outputs_weights_read b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_outputs_weights_read Binary files differnew file mode 100644 index 00000000000..a3cc7d765f6 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_saved_dnn_outputs_weights_read diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_Variable b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_Variable Binary files differnew file mode 100644 index 00000000000..e768328bff5 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_Variable diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_Variable_1 b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_Variable_1 Binary files differnew file mode 100644 index 00000000000..4fa0eadb0d3 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_Variable_1 diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read Binary files differnew file mode 100644 index 00000000000..4fa0eadb0d3 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_1_read diff --git a/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read Binary files differnew file mode 100644 index 00000000000..e768328bff5 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/constants/mnist_softmax_saved_layer_Variable_read diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg new file mode 100644 index 00000000000..1cc36f75158 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg @@ -0,0 +1,14 @@ +rankprofile[0].name "mnist_saved" +rankprofile[0].fef.property[0].name "rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add).rankingScript" +rankprofile[0].fef.property[0].value "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))" +rankprofile[0].fef.property[1].name "rankingExpression(serving_default.y).rankingScript" +rankprofile[0].fef.property[1].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))" +rankprofile[1].name "xgboost_2_2" +rankprofile[1].fef.property[0].name "rankingExpression(xgboost_2_2).rankingScript" +rankprofile[1].fef.property[0].value "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)" +rankprofile[2].name "mnist_softmax_saved" +rankprofile[2].fef.property[0].name "rankingExpression(serving_default.y).rankingScript" +rankprofile[2].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))" +rankprofile[3].name "mnist_softmax" +rankprofile[3].fef.property[0].name "rankingExpression(default.add).rankingScript" +rankprofile[3].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))" diff --git a/model-evaluation/src/test/resources/config/models/ranking-constants.cfg b/model-evaluation/src/test/resources/config/models/ranking-constants.cfg new file mode 100644 index 00000000000..2b7495ace5e --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/ranking-constants.cfg @@ -0,0 +1,30 @@ +constant[0].name "mnist_saved_dnn_hidden1_weights_read" +constant[0].fileref "" +constant[0].type "tensor(d3[300],d4[784])" +constant[1].name "mnist_saved_dnn_hidden2_weights_read" +constant[1].fileref "" +constant[1].type "tensor(d2[100],d3[300])" +constant[2].name "mnist_softmax_saved_layer_Variable_1_read" +constant[2].fileref "" +constant[2].type "tensor(d1[10])" +constant[3].name "mnist_saved_dnn_hidden1_bias_read" +constant[3].fileref "" +constant[3].type "tensor(d3[300])" +constant[4].name "mnist_saved_dnn_hidden2_bias_read" +constant[4].fileref "" +constant[4].type "tensor(d2[100])" +constant[5].name "mnist_softmax_Variable" +constant[5].fileref "" +constant[5].type "tensor(d1[10],d2[784])" +constant[6].name "mnist_saved_dnn_outputs_weights_read" +constant[6].fileref "" +constant[6].type "tensor(d1[10],d2[100])" +constant[7].name "mnist_softmax_saved_layer_Variable_read" +constant[7].fileref "" +constant[7].type "tensor(d1[10],d2[784])" +constant[8].name "mnist_softmax_Variable_1" +constant[8].fileref "" +constant[8].type "tensor(d1[10])" +constant[9].name "mnist_saved_dnn_outputs_bias_read" +constant[9].fileref "" +constant[9].type "tensor(d1[10])"
\ No newline at end of file diff --git a/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg b/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg diff --git a/node-admin/pom.xml b/node-admin/pom.xml index 7daeacec463..64958554f53 100644 --- a/node-admin/pom.xml +++ b/node-admin/pom.xml @@ -18,6 +18,7 @@ <name>${project.artifactId}</name> <dependencies> + <!-- Provided --> <dependency> <groupId>com.yahoo.vespa</groupId> <artifactId>docker-api</artifactId> @@ -32,49 +33,55 @@ </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>node-repository</artifactId> + <artifactId>defaults</artifactId> <version>${project.version}</version> + <scope>provided</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>defaults</artifactId> + <artifactId>container-dev</artifactId> <version>${project.version}</version> <scope>provided</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>container-dev</artifactId> + <artifactId>vespa-athenz</artifactId> <version>${project.version}</version> <scope>provided</scope> </dependency> + + <!-- Compile --> <dependency> - <groupId>net.jpountz.lz4</groupId> - <artifactId>lz4</artifactId> + <groupId>com.yahoo.vespa</groupId> + <artifactId>orchestrator-restapi</artifactId> + <version>${project.version}</version> + <scope>compile</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>node-repository</artifactId> + <version>${project.version}</version> <scope>compile</scope> </dependency> <dependency> <groupId>org.apache.httpcomponents</groupId> <artifactId>httpcore</artifactId> <version>4.4.1</version> + <scope>compile</scope> </dependency> <dependency> <groupId>org.apache.httpcomponents</groupId> <artifactId>httpclient</artifactId> <version>4.5</version> - </dependency> - <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>orchestrator-restapi</artifactId> - <version>${project.version}</version> <scope>compile</scope> </dependency> <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>vespa-athenz</artifactId> - <version>${project.version}</version> - <scope>provided</scope> + <groupId>org.apache.velocity</groupId> + <artifactId>velocity</artifactId> + <scope>compile</scope> </dependency> + <!-- Test --> <dependency> <groupId>org.hamcrest</groupId> <artifactId>hamcrest-junit</artifactId> @@ -82,6 +89,11 @@ <scope>test</scope> </dependency> <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-core</artifactId> + <scope>test</scope> + </dependency> + <dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> <scope>test</scope> @@ -89,24 +101,24 @@ <dependency> <groupId>com.yahoo.vespa</groupId> <artifactId>application</artifactId> - <scope>test</scope> <version>${project.version}</version> + <scope>test</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>orchestrator</artifactId> + <artifactId>application-model</artifactId> <version>${project.version}</version> <scope>test</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>service-monitor</artifactId> + <artifactId>orchestrator</artifactId> <version>${project.version}</version> <scope>test</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>application-model</artifactId> + <artifactId>service-monitor</artifactId> <version>${project.version}</version> <scope>test</scope> </dependency> @@ -116,16 +128,6 @@ <version>${project.version}</version> <scope>test</scope> </dependency> - <dependency> - <groupId>org.mockito</groupId> - <artifactId>mockito-core</artifactId> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.apache.velocity</groupId> - <artifactId>velocity</artifactId> - <scope>compile</scope> - </dependency> </dependencies> <build> <plugins> diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java index f5f0fa5a3f1..7036f6852fe 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java @@ -641,6 +641,7 @@ public class NodeSpec { public Builder updateFromNodeAttributes(NodeAttributes attributes) { attributes.getDockerImage().ifPresent(this::currentDockerImage); + attributes.getCurrentOsVersion().ifPresent(this::currentOsVersion); attributes.getHardwareDivergence().ifPresent(this::hardwareDivergence); attributes.getRebootGeneration().ifPresent(this::currentRebootGeneration); attributes.getRestartGeneration().ifPresent(this::currentRestartGeneration); diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java index f3b5dc9342a..e558cb5bdb2 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java @@ -87,33 +87,30 @@ public class DockerOperationsImpl implements DockerOperations { .withUlimit("nproc", 32_768, 409_600) .withUlimit("core", -1, -1) .withAddCapability("SYS_PTRACE") // Needed for gcore, pstack etc. - .withAddCapability("SYS_ADMIN") // Needed for perf - - // TODO: Fix. Run containers as privileged in AWS because mapped directories are on another device - .withPrivileged(environment.getCloud().equalsIgnoreCase("aws")); + .withAddCapability("SYS_ADMIN"); // Needed for perf if (environment.getNodeType() == NodeType.confighost || environment.getNodeType() == NodeType.proxyhost) { command.withVolume("/var/lib/sia", "/var/lib/sia"); } + if (environment.getNodeType() == NodeType.proxyhost) { + command.withVolume("/opt/yahoo/share/ssl/certs/", "/opt/yahoo/share/ssl/certs/"); + } + if (environment.getNodeType() == NodeType.host) { Path zpePathInNode = environment.pathInNodeUnderVespaHome("var/zpe"); if (environment.isRunningOnHost()) { - command.withVolume("/var/zpe", zpePathInNode.toString()); + command.withSharedVolume("/var/zpe", zpePathInNode.toString()); } else { command.withVolume(environment.pathInHostFromPathInNode(containerName, zpePathInNode).toString(), zpePathInNode.toString()); } } - if (environment.getNodeType() == NodeType.proxyhost) { - command.withVolume("/opt/yahoo/share/ssl/certs/", "/opt/yahoo/share/ssl/certs/"); - } - if (!docker.networkNATed()) { command.withIpAddress(ipV6Address); command.withNetworkMode(DockerImpl.DOCKER_CUSTOM_MACVLAN_NETWORK_NAME); - command.withVolume("/etc/hosts", "/etc/hosts"); + command.withSharedVolume("/etc/hosts", "/etc/hosts"); } else { InetAddress ipV6Prefix = InetAddresses.forString(IPV6_NPT_PREFIX); InetAddress ipV6Local = IPAddresses.prefixTranslate(ipV6Address, ipV6Prefix, 8); @@ -368,9 +365,6 @@ public class DockerOperationsImpl implements DockerOperations { directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/db/vespa"), false); directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/jdisc_container"), false); directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/jdisc_core"), false); - if (environment.getNodeType() == NodeType.host) { - directoriesToMount.put(Paths.get("/var/lib/sia"), true); - } directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/maven"), false); directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/run"), false); directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/scoreboards"), true); @@ -385,6 +379,8 @@ public class DockerOperationsImpl implements DockerOperations { directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/container-data"), false); if (environment.getNodeType() == NodeType.proxyhost) directoriesToMount.put(environment.pathInNodeUnderVespaHome("var/vespa-hosted/routing"), true); + if (environment.getNodeType() == NodeType.host) + directoriesToMount.put(Paths.get("/var/lib/sia"), true); return Collections.unmodifiableMap(directoriesToMount); } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java index f82047d885c..3871bb82313 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java @@ -1,6 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.node.admin.maintenance.identity; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.SslContextBuilder; +import com.yahoo.security.X509CertificateUtils; import com.yahoo.vespa.athenz.api.AthenzService; import com.yahoo.vespa.athenz.client.zts.DefaultZtsClient; import com.yahoo.vespa.athenz.client.zts.InstanceIdentity; @@ -13,12 +18,6 @@ import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; import com.yahoo.vespa.athenz.identityprovider.client.DefaultIdentityDocumentClient; import com.yahoo.vespa.athenz.identityprovider.client.InstanceCsrGenerator; import com.yahoo.vespa.athenz.tls.AthenzIdentityVerifier; -import com.yahoo.vespa.athenz.tls.KeyAlgorithm; -import com.yahoo.vespa.athenz.tls.KeyStoreType; -import com.yahoo.vespa.athenz.tls.KeyUtils; -import com.yahoo.vespa.athenz.tls.Pkcs10Csr; -import com.yahoo.vespa.athenz.tls.SslContextBuilder; -import com.yahoo.vespa.athenz.tls.X509CertificateUtils; import com.yahoo.vespa.athenz.utils.SiaUtils; import com.yahoo.vespa.hosted.dockerapi.ContainerName; import com.yahoo.vespa.hosted.node.admin.component.Environment; @@ -169,10 +168,11 @@ public class AthenzCredentialsMaintainer { return now.isAfter(expiry.minus(EXPIRY_MARGIN)); } + @SuppressWarnings("deprecation") private void registerIdentity() { KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); SignedIdentityDocument signedIdentityDocument = identityDocumentClient.getNodeIdentityDocument(hostname); - Pkcs10Csr csr = csrGenerator.generateCsr( + com.yahoo.vespa.athenz.tls.Pkcs10Csr csr = csrGenerator.generateCsr( containerIdentity, signedIdentityDocument.providerUniqueId(), signedIdentityDocument.ipAddresses(), keyPair); try (ZtsClient ztsClient = new DefaultZtsClient(ztsEndpoint, hostIdentityProvider)) { InstanceIdentity instanceIdentity = @@ -191,14 +191,15 @@ public class AthenzCredentialsMaintainer { } } + @SuppressWarnings("deprecation") private void refreshIdentity() { SignedIdentityDocument identityDocument = EntityBindingsMapper.readSignedIdentityDocumentFromFile(identityDocumentFile); KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); - Pkcs10Csr csr = csrGenerator.generateCsr(containerIdentity, identityDocument.providerUniqueId(), identityDocument.ipAddresses(), keyPair); + com.yahoo.vespa.athenz.tls.Pkcs10Csr csr = csrGenerator.generateCsr(containerIdentity, identityDocument.providerUniqueId(), identityDocument.ipAddresses(), keyPair); SSLContext containerIdentitySslContext = new SslContextBuilder() - .withKeyStore(privateKeyFile.toFile(), certificateFile.toFile()) - .withTrustStore(trustStorePath.toFile(), KeyStoreType.JKS) + .withKeyStore(privateKeyFile, certificateFile) + .withTrustStore(trustStorePath, KeyStoreType.JKS) .build(); try { try (ZtsClient ztsClient = new DefaultZtsClient(ztsEndpoint, containerIdentity, containerIdentitySslContext)) { diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdmin.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdmin.java index a0657c3d34c..16992bcb13a 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdmin.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdmin.java @@ -6,7 +6,6 @@ import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeSpec; import java.time.Duration; import java.util.List; import java.util.Map; -import java.util.Set; /** * NodeAdmin manages the life cycle of NodeAgents. diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java index 96e1461bc32..ba8a2e55587 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java @@ -157,8 +157,8 @@ public class NodeAdminImpl implements NodeAdmin { Map<String, Object> debug = new LinkedHashMap<>(); debug.put("isFrozen", isFrozen); - List<Map<String, Object>> nodeAgentDebugs = nodeAgentsByHostname.entrySet().stream() - .map(node -> node.getValue().debugInfo()).collect(Collectors.toList()); + List<Map<String, Object>> nodeAgentDebugs = nodeAgentsByHostname.values().stream() + .map(NodeAgent::debugInfo).collect(Collectors.toList()); debug.put("NodeAgents", nodeAgentDebugs); return debug; } @@ -171,7 +171,7 @@ public class NodeAdminImpl implements NodeAdmin { } catch (Throwable e) { logger.warning("Metric fetcher scheduler failed", e); } - }, 0, 55, TimeUnit.SECONDS); + }, 10, 55, TimeUnit.SECONDS); int delay = 120; // WARNING: Reducing this will increase the load on config servers. aclScheduler.scheduleWithFixedDelay(() -> { diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java index 5f2093c4719..7c84150009e 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java @@ -26,14 +26,11 @@ import com.yahoo.vespa.hosted.node.admin.maintenance.identity.AthenzCredentialsM import com.yahoo.vespa.hosted.node.admin.util.PrefixLogger; import com.yahoo.vespa.hosted.provision.Node; -import java.text.SimpleDateFormat; import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; -import java.util.Date; import java.util.LinkedHashMap; -import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Optional; @@ -80,9 +77,6 @@ public class NodeAgentImpl implements NodeAgent { private final Duration timeBetweenEachConverge; private final AthenzCredentialsMaintainer athenzCredentialsMaintainer; - private final SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); - private final LinkedList<String> debugMessages = new LinkedList<>(); - private int numberOfUnhandledException = 0; private Instant lastConverge; @@ -155,7 +149,7 @@ public class NodeAgentImpl implements NodeAgent { synchronized (monitor) { if (wantFrozen != frozen) { wantFrozen = frozen; - addDebugMessage(wantFrozen ? "Freezing" : "Unfreezing"); + logger.debug(wantFrozen ? "Freezing" : "Unfreezing"); signalWorkToBeDone(); } @@ -163,17 +157,6 @@ public class NodeAgentImpl implements NodeAgent { } } - private void addDebugMessage(String message) { - synchronized (debugMessages) { - while (debugMessages.size() > 1000) { - debugMessages.pop(); - } - - logger.debug(message); - debugMessages.add("[" + sdf.format(new Date()) + "] " + message); - } - } - @Override public Map<String, Object> debugInfo() { Map<String, Object> debug = new LinkedHashMap<>(); @@ -182,18 +165,13 @@ public class NodeAgentImpl implements NodeAgent { debug.put("wantFrozen", wantFrozen); debug.put("terminated", terminated); debug.put("workToDoNow", workToDoNow); - synchronized (debugMessages) { - debug.put("history", new LinkedList<>(debugMessages)); - } debug.put("nodeRepoState", lastNode.getState().name()); return debug; } @Override public void start() { - String message = "Starting with interval " + timeBetweenEachConverge.toMillis() + " ms"; - logger.info(message); - addDebugMessage(message); + logger.info("Starting with interval " + timeBetweenEachConverge.toMillis() + " ms"); loopThread.start(); @@ -213,7 +191,6 @@ public class NodeAgentImpl implements NodeAgent { @Override public void stop() { - addDebugMessage("Stopping"); filebeatRestarter.shutdown(); if (!terminated.compareAndSet(false, true)) { throw new RuntimeException("Can not re-stop a node agent."); @@ -240,7 +217,7 @@ public class NodeAgentImpl implements NodeAgent { currentFilebeatRestarter = Optional.of(filebeatRestarter.scheduleWithFixedDelay( () -> serviceRestarter.accept("filebeat"), 1, 1, TimeUnit.DAYS)); - addDebugMessage("Starting optional node program resume command"); + logger.debug("Starting optional node program resume command"); dockerOperations.resumeNode(containerName); resumeScriptRun = true; } @@ -266,8 +243,6 @@ public class NodeAgentImpl implements NodeAgent { if (!currentAttributes.equals(wantedAttributes)) { logger.info("Publishing new set of attributes to node repo: " + currentAttributes + " -> " + wantedAttributes); - addDebugMessage("Publishing new set of attributes to node repo: {" + - currentAttributes + "} -> {" + wantedAttributes + "}"); nodeRepository.updateNodeAttributes(hostname, wantedAttributes); } } @@ -386,7 +361,7 @@ public class NodeAgentImpl implements NodeAgent { synchronized (monitor) { if (!workToDoNow) { workToDoNow = true; - addDebugMessage("Signaling work to be done"); + logger.debug("Signaling work to be done"); monitor.notifyAll(); } } @@ -421,21 +396,19 @@ public class NodeAgentImpl implements NodeAgent { boolean converged = false; if (isFrozenCopy) { - addDebugMessage("tick: isFrozen"); + logger.debug("tick: isFrozen"); } else { try { converge(); converged = true; } catch (OrchestratorException e) { logger.info(e.getMessage()); - addDebugMessage(e.getMessage()); } catch (DockerException e) { numberOfUnhandledException++; logger.error("Caught a DockerException, resetting containerState to " + containerState, e); } catch (Exception e) { numberOfUnhandledException++; logger.error("Unhandled exception, ignoring.", e); - addDebugMessage(e.getMessage()); } } @@ -462,7 +435,7 @@ public class NodeAgentImpl implements NodeAgent { storageMaintainer.writeMetricsConfig(containerName, node); } - addDebugMessage("Loading new node spec: " + node.toString()); + logger.debug("Loading new node spec: " + node.toString()); lastNode = node; } @@ -484,7 +457,7 @@ public class NodeAgentImpl implements NodeAgent { scheduleDownLoadIfNeeded(node); if (isDownloadingImage()) { - addDebugMessage("Waiting for image to download " + imageBeingDownloaded.asString()); + logger.debug("Waiting for image to download " + imageBeingDownloaded.asString()); return; } container = removeContainerIfNeededUpdateContainerState(node, container); diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepo.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepo.java deleted file mode 100644 index 5df790f9105..00000000000 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepo.java +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.node.admin.task.util.yum; - -import com.yahoo.vespa.hosted.node.admin.component.TaskContext; -import com.yahoo.vespa.hosted.node.admin.task.util.file.FileWriter; - -import java.nio.file.FileSystem; -import java.nio.file.FileSystems; -import java.nio.file.Path; -import java.util.regex.Pattern; - -/** - * @author hakonhall - */ -public class AddYumRepo { - private static final Pattern REPOSITORY_ID_PATTERN = Pattern.compile("^[a-zA-Z0-9_-]+$"); - - private final String repositoryId; // e.g. "platform_rpms-latest" - private final String name; // e.g. "Platform RPM Latest Repo" - private final String baseurl; - private final boolean enabled; - private final FileSystem fileSystem; - - public AddYumRepo(String repositoryId, - String name, - String baseurl, - boolean enabled) { - this(repositoryId, name, baseurl, enabled, FileSystems.getDefault()); - } - - public boolean converge(TaskContext context) { - Path path = fileSystem.getPath("/etc/yum.repos.d",repositoryId + ".repo"); - - FileWriter fileWriter = new FileWriter(path, this::getRepoFileContent) - .withOwner("root") - .withGroup("root") - .withPermissions("rw-r--r--") - .onlyIfFileDoesNotAlreadyExist(); - - return fileWriter.converge(context); - } - - private String getRepoFileContent() { - return String.join("\n", - "# This file was generated by node admin", - "# Do NOT modify this file by hand", - "", - "[" + repositoryId + "]", - "name=" + name, - "baseurl=" + baseurl, - "enabled=" + (enabled ? 1 : 0), - "gpgcheck=0" - ) + "\n"; - } - - private static void validateRepositoryId(String repositoryId) { - if (!REPOSITORY_ID_PATTERN.matcher(repositoryId).matches()) { - throw new IllegalArgumentException("Invalid repository ID '" + repositoryId + "'"); - } - } - - // For testing - public AddYumRepo(String repositoryId, - String name, - String baseurl, - boolean enabled, - FileSystem fileSystem) { - this.repositoryId = repositoryId; - this.name = name; - this.baseurl = baseurl; - this.enabled = enabled; - this.fileSystem = fileSystem; - validateRepositoryId(repositoryId); - } -} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/Yum.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/Yum.java index 35e3d1f8b9e..3a9c49a0f2d 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/Yum.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/Yum.java @@ -5,6 +5,7 @@ import com.yahoo.vespa.hosted.node.admin.component.TaskContext; import com.yahoo.vespa.hosted.node.admin.task.util.process.CommandLine; import com.yahoo.vespa.hosted.node.admin.task.util.process.Terminal; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Optional; @@ -162,7 +163,7 @@ public class Yum { private final List<YumPackageName> packages; private final Pattern commandOutputNoopPattern; - private Optional<String> enabledRepo = Optional.empty(); + private final List<String> enabledRepo = new ArrayList<>(); private GenericYumCommand(Terminal terminal, String yumCommand, @@ -179,15 +180,15 @@ public class Yum { } @SuppressWarnings("unchecked") - public GenericYumCommand enableRepo(String repo) { - enabledRepo = Optional.of(repo); + public GenericYumCommand enableRepos(String... repos) { + enabledRepo.addAll(Arrays.asList(repos)); return this; } public boolean converge(TaskContext context) { CommandLine commandLine = terminal.newCommandLine(context); commandLine.add("yum", yumCommand, "--assumeyes"); - enabledRepo.ifPresent(repo -> commandLine.add("--enablerepo=" + repo)); + enabledRepo.forEach(repo -> commandLine.add("--enablerepo=" + repo)); commandLine.add(packages.stream().map(YumPackageName::toName).collect(Collectors.toList())); // There's no way to figure out whether a yum command would have been a no-op. diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerMock.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerMock.java index 9b9bb2af26c..4b4ef05593d 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerMock.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerMock.java @@ -167,6 +167,11 @@ public class DockerMock implements Docker { } @Override + public CreateContainerCommand withSharedVolume(String path, String volumePath) { + return this; + } + + @Override public CreateContainerCommand withNetworkMode(String mode) { return this; } diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerTester.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerTester.java index 603ad3ebccf..d0e4377ffc5 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerTester.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/integrationTests/DockerTester.java @@ -14,7 +14,6 @@ import com.yahoo.vespa.hosted.node.admin.docker.DockerOperations; import com.yahoo.vespa.hosted.node.admin.docker.DockerOperationsImpl; import com.yahoo.vespa.hosted.node.admin.maintenance.acl.AclMaintainer; import com.yahoo.vespa.hosted.node.admin.maintenance.identity.AthenzCredentialsMaintainer; -import com.yahoo.vespa.hosted.node.admin.nodeadmin.NodeAdmin; import com.yahoo.vespa.hosted.node.admin.nodeadmin.NodeAdminImpl; import com.yahoo.vespa.hosted.node.admin.nodeadmin.NodeAdminStateUpdaterImpl; import com.yahoo.vespa.hosted.node.admin.nodeagent.NodeAgent; diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepoTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepoTest.java deleted file mode 100644 index c6314439003..00000000000 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepoTest.java +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.vespa.hosted.node.admin.task.util.yum; - -import com.yahoo.vespa.hosted.node.admin.component.TaskContext; -import com.yahoo.vespa.hosted.node.admin.task.util.file.UnixPath; -import com.yahoo.vespa.test.file.TestFileSystem; -import org.junit.Test; - -import java.nio.file.FileSystem; -import java.time.Instant; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.mock; - -public class AddYumRepoTest { - @Test - public void converge() { - String repositoryId = "repoid"; - String name = "name"; - String baseurl = "http://foo.com/bar"; - boolean enabled = true; - - FileSystem fileSystem = TestFileSystem.create(); - AddYumRepo addYumRepo = new AddYumRepo( - repositoryId, - name, - baseurl, - enabled, - fileSystem); - - TaskContext context = mock(TaskContext.class); - - assertTrue(addYumRepo.converge(context)); - - UnixPath unixPath = new UnixPath(fileSystem.getPath("/etc/yum.repos.d/" + repositoryId + ".repo")); - String content = unixPath.readUtf8File(); - assertEquals("# This file was generated by node admin\n" + - "# Do NOT modify this file by hand\n" + - "\n" + - "[repoid]\n" + - "name=name\n" + - "baseurl=http://foo.com/bar\n" + - "enabled=1\n" + - "gpgcheck=0\n", content); - Instant lastModifiedTime = unixPath.getLastModifiedTime(); - - // Second time is a no-op - assertFalse(addYumRepo.converge(context)); - assertEquals(lastModifiedTime, unixPath.getLastModifiedTime()); - } - -}
\ No newline at end of file diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumTest.java index 4f2d65bf522..2e65c1aae09 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumTest.java @@ -47,13 +47,13 @@ public class YumTest { @Test public void testAlreadyInstalled() { terminal.expectCommand( - "yum install --assumeyes --enablerepo=repo-name package-1 package-2 2>&1", + "yum install --assumeyes --enablerepo=repo1 --enablerepo=repo2 package-1 package-2 2>&1", 0, "foobar\nNothing to do\n"); assertFalse(yum .install("package-1", "package-2") - .enableRepo("repo-name") + .enableRepos("repo1", "repo2") .converge(taskContext)); } @@ -102,7 +102,7 @@ public class YumTest { assertTrue(yum .install("package-1", "package-2") - .enableRepo("repo-name") + .enableRepos("repo-name") .converge(taskContext)); } @@ -185,7 +185,7 @@ public class YumTest { "error"); yum.install("package-1", "package-2") - .enableRepo("repo-name") + .enableRepos("repo-name") .converge(taskContext); fail(); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java index 0ef5c03e543..06a86cbddf7 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java @@ -81,7 +81,8 @@ public class NodeRepositoryProvisioner implements Provisioner { int effectiveGroups; NodeSpec requestedNodes; if ( requestedCapacity.type() == NodeType.tenant) { - int nodeCount = capacityPolicies.decideSize(requestedCapacity); + int nodeCount = application.instance().isTester() ? 1 : capacityPolicies.decideSize(requestedCapacity); + if (zone.environment().isManuallyDeployed() && nodeCount < requestedCapacity.nodeCount()) logger.log(Level.INFO, "Requested " + requestedCapacity.nodeCount() + " nodes for " + cluster + ", downscaling to " + nodeCount + " nodes in " + zone.environment()); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifier.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifier.java index 90c24f6bb23..0891279f30c 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifier.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifier.java @@ -6,8 +6,8 @@ import com.google.common.base.Suppliers; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.Zone; import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId; -import com.yahoo.vespa.athenz.tls.SubjectAlternativeName; -import com.yahoo.vespa.athenz.tls.X509CertificateUtils; +import com.yahoo.security.SubjectAlternativeName; +import com.yahoo.security.X509CertificateUtils; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeRepository; @@ -16,7 +16,7 @@ import java.util.List; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; -import static com.yahoo.vespa.athenz.tls.SubjectAlternativeName.Type.DNS_NAME; +import static com.yahoo.security.SubjectAlternativeName.Type.DNS_NAME; /** * Resolve node from various types of x509 identity certificates. diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/FilterTester.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/FilterTester.java index 6420a5237e8..caecce1634d 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/FilterTester.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/FilterTester.java @@ -5,9 +5,12 @@ import com.yahoo.application.container.handler.Request.Method; import com.yahoo.container.jdisc.RequestHandlerTestDriver; import com.yahoo.jdisc.http.filter.DiscFilterRequest; import com.yahoo.jdisc.http.filter.SecurityRequestFilter; -import com.yahoo.vespa.athenz.tls.X509CertificateBuilder; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.X509CertificateBuilder; import javax.security.auth.x500.X500Principal; +import java.math.BigInteger; import java.net.URI; import java.security.KeyPair; import java.security.KeyPairGenerator; @@ -20,7 +23,8 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import static com.yahoo.vespa.athenz.tls.SignatureAlgorithm.SHA256_WITH_RSA; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_ECDSA; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_RSA; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -65,7 +69,7 @@ public class FilterTester { when(r.getRemoteAddr()).thenReturn(request.remoteAddr()); when(r.getLocalAddr()).thenReturn(request.localAddr()); if (request.commonName().isPresent()) { - X509Certificate cert = certificateFor(request.commonName().get(), keyPair()); + X509Certificate cert = certificateFor(request.commonName().get(), KeyUtils.generateKeypair(KeyAlgorithm.EC)); List<X509Certificate> certs = Collections.singletonList(cert); when(r.getClientCertificateChain()).thenReturn(certs); when(r.getUserPrincipal()).thenReturn(NodePrincipal.withLegacyIdentity(request.commonName().get(), certs)); @@ -73,23 +77,13 @@ public class FilterTester { return r; } - /** Create a RSA public/private key pair */ - private static KeyPair keyPair() { - try { - KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA"); - keyGen.initialize(2048); - return keyGen.generateKeyPair(); - } catch (NoSuchAlgorithmException e) { - throw new RuntimeException(e); - } - } /** Create a self signed certificate for commonName using given public/private key pair */ private static X509Certificate certificateFor(String commonName, KeyPair keyPair) { Instant now = Instant.now(); X500Principal subject = new X500Principal("CN=" + commonName); return X509CertificateBuilder - .fromKeypair(keyPair, subject, now, now.plus(Duration.ofDays(30)), SHA256_WITH_RSA, now.toEpochMilli()) + .fromKeypair(keyPair, subject, now, now.plus(Duration.ofDays(30)), SHA256_WITH_ECDSA, BigInteger.valueOf(now.toEpochMilli())) .setBasicConstraints(true, true) .build(); } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java index d02a666eb69..f7d4a9603e7 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java @@ -12,10 +12,10 @@ import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.Zone; import com.yahoo.config.provisioning.FlavorsConfig; import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId; -import com.yahoo.vespa.athenz.tls.KeyUtils; -import com.yahoo.vespa.athenz.tls.Pkcs10Csr; -import com.yahoo.vespa.athenz.tls.Pkcs10CsrBuilder; -import com.yahoo.vespa.athenz.tls.X509CertificateBuilder; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.Pkcs10Csr; +import com.yahoo.security.Pkcs10CsrBuilder; +import com.yahoo.security.X509CertificateBuilder; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeRepositoryTester; import com.yahoo.vespa.hosted.provision.node.Allocation; @@ -26,14 +26,17 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import javax.security.auth.x500.X500Principal; +import java.math.BigInteger; import java.security.KeyPair; import java.security.cert.X509Certificate; import java.time.Instant; import java.util.Optional; +import static com.yahoo.security.KeyAlgorithm.EC; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_ECDSA; import static com.yahoo.vespa.athenz.identityprovider.api.IdentityType.*; -import static com.yahoo.vespa.athenz.tls.KeyAlgorithm.RSA; -import static com.yahoo.vespa.athenz.tls.SignatureAlgorithm.SHA256_WITH_RSA; +import static com.yahoo.security.KeyAlgorithm.RSA; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_RSA; import static com.yahoo.vespa.hosted.provision.restapi.v2.filter.NodeIdentifier.CONFIGSERVER_HOST_IDENTITY; import static com.yahoo.vespa.hosted.provision.restapi.v2.filter.NodeIdentifier.PROXY_HOST_IDENTITY; import static com.yahoo.vespa.hosted.provision.restapi.v2.filter.NodeIdentifier.TENANT_DOCKER_CONTAINER_IDENTITY; @@ -64,7 +67,7 @@ public class NodeIdentifierTest { private static final String INSTANCE_ID = "default"; private static final Zone ZONE = new Zone(SystemName.main, Environment.prod, RegionName.defaultName()); - private static final KeyPair KEYPAIR = KeyUtils.generateKeypair(RSA); + private static final KeyPair KEYPAIR = KeyUtils.generateKeypair(EC); private static final X509Certificate ATHENZ_YAHOO_CA_CERT = createDummyCaCertificate("Yahoo Athenz CA"); private static final X509Certificate ATHENZ_AWS_CA_CERT = createDummyCaCertificate("Athenz AWS CA"); @@ -73,7 +76,7 @@ public class NodeIdentifierTest { NodeRepositoryTester nodeRepositoryDummy = new NodeRepositoryTester(); X509Certificate certificate = X509CertificateBuilder .fromKeypair( - KEYPAIR, new X500Principal("CN=" + HOSTNAME), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), SHA256_WITH_RSA, 1) + KEYPAIR, new X500Principal("CN=" + HOSTNAME), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), SHA256_WITH_ECDSA, BigInteger.ONE) .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); expectedException.expect(NodeIdentifier.NodeIdentifierException.class); @@ -87,10 +90,10 @@ public class NodeIdentifierTest { nodeRepositoryDummy.addNode(OPENSTACK_ID, HOSTNAME, INSTANCE_ID, NodeType.host); nodeRepositoryDummy.setNodeState(HOSTNAME, Node.State.active); Pkcs10Csr csr = Pkcs10CsrBuilder - .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_HOST_IDENTITY), KEYPAIR, SHA256_WITH_RSA) + .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_HOST_IDENTITY), KEYPAIR, SHA256_WITH_ECDSA) .build(); X509Certificate certificate = X509CertificateBuilder - .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1) + .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_ECDSA, BigInteger.ONE) .addSubjectAlternativeName(OPENSTACK_ID + ".instanceid.athenz.provider-name.ostk.yahoo.cloud") .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); @@ -106,10 +109,10 @@ public class NodeIdentifierTest { nodeRepositoryDummy.addNode(AWS_INSTANCE_ID, HOSTNAME, INSTANCE_ID, NodeType.host); nodeRepositoryDummy.setNodeState(HOSTNAME, Node.State.active); Pkcs10Csr csr = Pkcs10CsrBuilder - .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_HOST_IDENTITY), KEYPAIR, SHA256_WITH_RSA) + .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_HOST_IDENTITY), KEYPAIR, SHA256_WITH_ECDSA) .build(); X509Certificate certificate = X509CertificateBuilder - .fromCsr(csr, ATHENZ_AWS_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1) + .fromCsr(csr, ATHENZ_AWS_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_ECDSA, BigInteger.ONE) .addSubjectAlternativeName(AWS_INSTANCE_ID + ".instanceid.athenz.aws.oath.cloud") .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); @@ -125,10 +128,10 @@ public class NodeIdentifierTest { nodeRepositoryDummy.addNode(AWS_INSTANCE_ID, PROXY_HOSTNAME, INSTANCE_ID, NodeType.proxyhost); nodeRepositoryDummy.setNodeState(PROXY_HOSTNAME, Node.State.active); Pkcs10Csr csr = Pkcs10CsrBuilder - .fromKeypair(new X500Principal("CN=" + PROXY_HOST_IDENTITY), KEYPAIR, SHA256_WITH_RSA) + .fromKeypair(new X500Principal("CN=" + PROXY_HOST_IDENTITY), KEYPAIR, SHA256_WITH_ECDSA) .build(); X509Certificate certificate = X509CertificateBuilder - .fromCsr(csr, ATHENZ_AWS_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1) + .fromCsr(csr, ATHENZ_AWS_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_ECDSA, BigInteger.ONE) .addSubjectAlternativeName(AWS_INSTANCE_ID + ".instanceid.athenz.aws.oath.cloud") .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); @@ -142,10 +145,10 @@ public class NodeIdentifierTest { public void accepts_aws_configserver_host_certificate() { NodeRepositoryTester nodeRepositoryDummy = new NodeRepositoryTester(); Pkcs10Csr csr = Pkcs10CsrBuilder - .fromKeypair(new X500Principal("CN=" + CONFIGSERVER_HOST_IDENTITY), KEYPAIR, SHA256_WITH_RSA) + .fromKeypair(new X500Principal("CN=" + CONFIGSERVER_HOST_IDENTITY), KEYPAIR, SHA256_WITH_ECDSA) .build(); X509Certificate certificate = X509CertificateBuilder - .fromCsr(csr, ATHENZ_AWS_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1) + .fromCsr(csr, ATHENZ_AWS_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_ECDSA, BigInteger.ONE) .addSubjectAlternativeName(AWS_INSTANCE_ID + ".instanceid.athenz.aws.oath.cloud") .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); @@ -156,7 +159,7 @@ public class NodeIdentifierTest { @Test public void accepts_zts_certificate() { X509Certificate certificate = X509CertificateBuilder - .fromKeypair(KEYPAIR, new X500Principal("CN=" + ZTS_AWS_IDENTITY), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), SHA256_WITH_RSA, 1) + .fromKeypair(KEYPAIR, new X500Principal("CN=" + ZTS_AWS_IDENTITY), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), SHA256_WITH_ECDSA, BigInteger.ONE) .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, new NodeRepositoryTester().nodeRepository()); NodePrincipal identity = identifier.resolveNode(singletonList(certificate)); @@ -176,11 +179,11 @@ public class NodeIdentifierTest { Node node = createNode(clusterId, clusterIndex, tenant, application); nodeRepositoryDummy.nodeRepository().addDockerNodes(singletonList(node)); Pkcs10Csr csr = Pkcs10CsrBuilder - .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_CONTAINER_IDENTITY), KEYPAIR, SHA256_WITH_RSA) + .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_CONTAINER_IDENTITY), KEYPAIR, SHA256_WITH_ECDSA) .build(); VespaUniqueInstanceId vespaUniqueInstanceId = new VespaUniqueInstanceId(clusterIndex, clusterId, INSTANCE_ID, application, tenant, region, environment, NODE); X509Certificate certificate = X509CertificateBuilder - .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1) + .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_ECDSA, BigInteger.ONE) .addSubjectAlternativeName(vespaUniqueInstanceId.asDottedString() + ".instanceid.athenz.provider-name.vespa.yahoo.cloud") .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); @@ -194,10 +197,10 @@ public class NodeIdentifierTest { public void accepts_controller_certificate() { NodeRepositoryTester nodeRepositoryDummy = new NodeRepositoryTester(); Pkcs10Csr csr = Pkcs10CsrBuilder - .fromKeypair(new X500Principal("CN=" + CONTROLLER_IDENTITY), KEYPAIR, SHA256_WITH_RSA) + .fromKeypair(new X500Principal("CN=" + CONTROLLER_IDENTITY), KEYPAIR, SHA256_WITH_ECDSA) .build(); X509Certificate certificate = X509CertificateBuilder - .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1) + .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_ECDSA, BigInteger.ONE) .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); NodePrincipal identity = identifier.resolveNode(singletonList(certificate)); @@ -211,10 +214,10 @@ public class NodeIdentifierTest { nodeRepositoryDummy.addNode(OPENSTACK_ID, HOSTNAME, INSTANCE_ID, NodeType.tenant); nodeRepositoryDummy.setNodeState(HOSTNAME, Node.State.active); Pkcs10Csr csr = Pkcs10CsrBuilder - .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_CONTAINER_IDENTITY), KEYPAIR, SHA256_WITH_RSA) + .fromKeypair(new X500Principal("CN=" + TENANT_DOCKER_CONTAINER_IDENTITY), KEYPAIR, SHA256_WITH_ECDSA) .build(); X509Certificate certificate = X509CertificateBuilder - .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1) + .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_ECDSA, BigInteger.ONE) .addSubjectAlternativeName(OPENSTACK_ID + ".instanceid.athenz.ostk.yahoo.cloud") .build(); NodeIdentifier identifier = new NodeIdentifier(ZONE, nodeRepositoryDummy.nodeRepository()); @@ -251,10 +254,10 @@ public class NodeIdentifierTest { } private static X509Certificate createDummyCaCertificate(String caCommonName) { - KeyPair keyPair = KeyUtils.generateKeypair(RSA); + KeyPair keyPair = KeyUtils.generateKeypair(EC); return X509CertificateBuilder .fromKeypair( - keyPair, new X500Principal("CN=" + caCommonName), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), SHA256_WITH_RSA, 1) + keyPair, new X500Principal("CN=" + caCommonName), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), SHA256_WITH_ECDSA, BigInteger.ONE) .setBasicConstraints(true, true) .build(); @@ -1,5 +1,5 @@ <?xml version="1.0" encoding="UTF-8"?> -<!-- Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> <modelVersion>4.0.0</modelVersion> <groupId>com.yahoo.vespa</groupId> diff --git a/searchcommon/src/vespa/searchcommon/attribute/attributecontent.h b/searchcommon/src/vespa/searchcommon/attribute/attributecontent.h index 508a2d04c27..72ce1754d71 100644 --- a/searchcommon/src/vespa/searchcommon/attribute/attributecontent.h +++ b/searchcommon/src/vespa/searchcommon/attribute/attributecontent.h @@ -145,7 +145,7 @@ public: search::attribute::IAttributeVector::DocId docId) { uint32_t count = attribute.get(docId, data(), capacity()); - if (count > capacity()) { + while (count > capacity()) { allocate(count); count = attribute.get(docId, data(), capacity()); } diff --git a/searchcore/pom.xml b/searchcore/pom.xml index 3b43bf1205e..002ba1f508f 100644 --- a/searchcore/pom.xml +++ b/searchcore/pom.xml @@ -1,3 +1,4 @@ +<?xml version="1.0"?> <!-- Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" @@ -11,8 +12,8 @@ <relativePath>../parent/pom.xml</relativePath> </parent> <artifactId>searchcore</artifactId> - <version>6-SNAPSHOT</version> <packaging>jar</packaging> + <version>6-SNAPSHOT</version> <name>${project.artifactId}</name> <dependencies> <dependency> diff --git a/searchcore/src/main/java/com/yahoo/vespa/config/search/core/package-info.java b/searchcore/src/main/java/com/yahoo/vespa/config/search/core/package-info.java new file mode 100644 index 00000000000..c29162d65ae --- /dev/null +++ b/searchcore/src/main/java/com/yahoo/vespa/config/search/core/package-info.java @@ -0,0 +1,7 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +@ExportPackage +package com.yahoo.vespa.config.search.core; + +import com.yahoo.osgi.annotation.ExportPackage; + diff --git a/searchcore/src/tests/proton/flushengine/CMakeLists.txt b/searchcore/src/tests/proton/flushengine/CMakeLists.txt index 826c9b2390f..6e8df3c9b7f 100644 --- a/searchcore/src/tests/proton/flushengine/CMakeLists.txt +++ b/searchcore/src/tests/proton/flushengine/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. vespa_add_executable(searchcore_flushengine_test_app TEST SOURCES - flushengine.cpp + flushengine_test.cpp DEPENDS searchcore_flushengine searchcore_pcommon diff --git a/searchcore/src/tests/proton/flushengine/flushengine.cpp b/searchcore/src/tests/proton/flushengine/flushengine_test.cpp index d1a98f1b7d3..f668072b9fd 100644 --- a/searchcore/src/tests/proton/flushengine/flushengine.cpp +++ b/searchcore/src/tests/proton/flushengine/flushengine_test.cpp @@ -3,15 +3,15 @@ #include <vespa/searchcore/proton/flushengine/cachedflushtarget.h> #include <vespa/searchcore/proton/flushengine/flush_engine_explorer.h> #include <vespa/searchcore/proton/flushengine/flushengine.h> +#include <vespa/searchcore/proton/flushengine/i_tls_stats_factory.h> #include <vespa/searchcore/proton/flushengine/threadedflushtarget.h> #include <vespa/searchcore/proton/flushengine/tls_stats_map.h> -#include <vespa/searchcore/proton/flushengine/i_tls_stats_factory.h> #include <vespa/searchcore/proton/server/igetserialnum.h> #include <vespa/searchcore/proton/test/dummy_flush_handler.h> #include <vespa/searchcore/proton/test/dummy_flush_target.h> -#include <vespa/vespalib/testkit/testapp.h> #include <vespa/vespalib/data/slime/slime.h> #include <vespa/vespalib/test/insertion_operators.h> +#include <vespa/vespalib/testkit/testapp.h> #include <mutex> #include <chrono> @@ -42,7 +42,6 @@ public: SimpleExecutor() : _done() { - // empty } Task::UP @@ -113,7 +112,7 @@ public: } }; -typedef std::vector<IFlushTarget::SP> Targets; +using Targets = std::vector<IFlushTarget::SP>; using FlushDoneHistory = std::vector<search::SerialNum>; @@ -141,7 +140,6 @@ public: _done(targets.size()), _flushDoneHistory() { - // empty } search::SerialNum @@ -219,7 +217,6 @@ public: : _flushedSerial(flushedSerial), _currentSerial(currentSerial), _start(start), _done(done), _proceed(proceed) { - // empty } void run() override { @@ -248,39 +245,45 @@ public: vespalib::Gate _taskDone; Task::UP _task; -public: - typedef std::shared_ptr<SimpleTarget> SP; - - SimpleTarget(Task::UP task, const std::string &name) : - test::DummyFlushTarget(name), - _flushedSerial(0), - _currentSerial(0), +protected: + SimpleTarget(const std::string &name, const Type &type, search::SerialNum flushedSerial = 0, bool proceedImmediately = true) : + test::DummyFlushTarget(name, type, Component::OTHER), + _flushedSerial(flushedSerial), _proceed(), _initDone(), _taskStart(), _taskDone(), - _task(std::move(task)) + _task(std::make_unique<SimpleTask>(_taskStart, _taskDone, &_proceed, + _flushedSerial, _currentSerial)) { + if (proceedImmediately) { + _proceed.countDown(); + } } - SimpleTarget(const std::string &name, search::SerialNum flushedSerial = 0, bool proceedImmediately = true) : +public: + using SP = std::shared_ptr<SimpleTarget>; + + SimpleTarget(Task::UP task, const std::string &name) : test::DummyFlushTarget(name), - _flushedSerial(flushedSerial), + _flushedSerial(0), + _currentSerial(0), _proceed(), _initDone(), _taskStart(), _taskDone(), - _task(new SimpleTask(_taskStart, _taskDone, &_proceed, - _flushedSerial, _currentSerial)) + _task(std::move(task)) { - if (proceedImmediately) { - _proceed.countDown(); - } } + SimpleTarget(search::SerialNum flushedSerial = 0, bool proceedImmediately = true) : SimpleTarget("anon", flushedSerial, proceedImmediately) { } + SimpleTarget(const std::string &name, search::SerialNum flushedSerial = 0, bool proceedImmediately = true) + : SimpleTarget(name, Type::OTHER, flushedSerial, proceedImmediately) + { } + virtual Time getLastFlushTime() const override { return fastos::ClockSystem::now(); } @@ -304,6 +307,13 @@ public: }; +class GCTarget : public SimpleTarget { +public: + GCTarget(const vespalib::string &name, search::SerialNum flushedSerial) + : SimpleTarget(name, Type::GC, flushedSerial) + {} +}; + class AssertedTarget : public SimpleTarget { public: mutable bool _mgain; @@ -366,10 +376,7 @@ public: public: typedef std::shared_ptr<SimpleStrategy> SP; - SimpleStrategy() - { - // empty - } + SimpleStrategy() {} uint32_t indexOf(const IFlushTarget::SP &target) const @@ -449,6 +456,14 @@ struct Fixture { } + void putFlushHandler(const vespalib::string &docTypeName, IFlushHandler::SP handler) { + engine.putFlushHandler(DocTypeName(docTypeName), handler); + } + + void addTargetToStrategy(IFlushTarget::SP target) { + strategy->_targets.push_back(std::move(target)); + } + std::shared_ptr<SimpleHandler> addSimpleHandler(Targets targets) { @@ -471,21 +486,17 @@ struct Fixture } }; - TEST_F("require that strategy controls flush target", Fixture(1, IINTERVAL)) { vespalib::Gate fooG, barG; std::vector<vespalib::string> order; - FlushTask::UP fooT(new AppendTask("foo", order, fooG)); - FlushTask::UP barT(new AppendTask("bar", order, barG)); - SimpleTarget::SP foo(new SimpleTarget(std::move(fooT), "foo")); - SimpleTarget::SP bar(new SimpleTarget(std::move(barT), "bar")); - f.strategy->_targets.push_back(foo); - f.strategy->_targets.push_back(bar); - - SimpleHandler::SP handler(new SimpleHandler({bar, foo})); - DocTypeName dtnvanon("anon"); - f.engine.putFlushHandler(dtnvanon, handler); + auto foo = std::make_shared<SimpleTarget>(std::make_unique<AppendTask>("foo", order, fooG), "foo"); + auto bar = std::make_shared<SimpleTarget>(std::make_unique<AppendTask>("bar", order, barG), "bar"); + f.addTargetToStrategy(foo); + f.addTargetToStrategy(bar); + + auto handler = std::make_shared<SimpleHandler>(Targets({bar, foo}), "anon"); + f.putFlushHandler("anon", handler); f.engine.start(); EXPECT_TRUE(fooG.await(LONG_TIMEOUT)); @@ -502,25 +513,20 @@ TEST_F("require that zero handlers does not core", Fixture(2, 50)) TEST_F("require that zero targets does not core", Fixture(2, 50)) { - DocTypeName dtnvfoo("foo"); - DocTypeName dtnvbar("bar"); - f.engine.putFlushHandler(dtnvfoo, - IFlushHandler::SP(new SimpleHandler({}, "foo"))); - f.engine.putFlushHandler(dtnvbar, - IFlushHandler::SP(new SimpleHandler({}, "bar"))); + f.putFlushHandler("foo", std::make_shared<SimpleHandler>(Targets(), "foo")); + f.putFlushHandler("bar", std::make_shared<SimpleHandler>(Targets(), "bar")); f.engine.start(); } TEST_F("require that oldest serial is found", Fixture(1, IINTERVAL)) { - SimpleTarget::SP foo(new SimpleTarget("foo", 10)); - SimpleTarget::SP bar(new SimpleTarget("bar", 20)); - f.strategy->_targets.push_back(foo); - f.strategy->_targets.push_back(bar); - - SimpleHandler::SP handler(new SimpleHandler({foo, bar}, "anon", 25)); - DocTypeName dtnvanon("anon"); - f.engine.putFlushHandler(dtnvanon, handler); + auto foo = std::make_shared<SimpleTarget>("foo", 10); + auto bar = std::make_shared<SimpleTarget>("bar", 20); + f.addTargetToStrategy(foo); + f.addTargetToStrategy(bar); + + auto handler = std::make_shared<SimpleHandler>(Targets({foo, bar}), "anon", 25); + f.putFlushHandler("anon", handler); f.engine.start(); EXPECT_TRUE(handler->_done.await(LONG_TIMEOUT)); @@ -529,24 +535,44 @@ TEST_F("require that oldest serial is found", Fixture(1, IINTERVAL)) EXPECT_EQUAL(FlushDoneHistory({ 10, 20, 25 }), handlerFlushDoneHistory); } +TEST_F("require that GC targets are not considered when oldest serial is found", Fixture(1, IINTERVAL)) +{ + auto foo = std::make_shared<SimpleTarget>("foo", 5); + auto bar = std::make_shared<GCTarget>("bar", 10); + auto baz = std::make_shared<SimpleTarget>("baz", 20); + f.addTargetToStrategy(foo); + f.addTargetToStrategy(bar); + f.addTargetToStrategy(baz); + + auto handler = std::make_shared<SimpleHandler>(Targets({foo, bar, baz}), "handler", 25); + f.putFlushHandler("handler", handler); + f.engine.start(); + + // The targets are flushed in sequence: 'foo', 'bar', 'baz' + EXPECT_TRUE(handler->_done.await(LONG_TIMEOUT)); + EXPECT_EQUAL(25ul, handler->_oldestSerial); + + // Before anything is flushed the oldest serial is 5. + // After 'foo' has been flushed the oldest serial is 20 as GC target 'bar' is not considered. + EXPECT_EQUAL(FlushDoneHistory({ 5, 20, 20, 25 }), handler->getFlushDoneHistory()); +} + TEST_F("require that oldest serial is found in group", Fixture(2, IINTERVAL)) { - SimpleTarget::SP fooT1(new SimpleTarget("fooT1", 10)); - SimpleTarget::SP fooT2(new SimpleTarget("fooT2", 20)); - SimpleTarget::SP barT1(new SimpleTarget("barT1", 5)); - SimpleTarget::SP barT2(new SimpleTarget("barT2", 15)); - f.strategy->_targets.push_back(fooT1); - f.strategy->_targets.push_back(fooT2); - f.strategy->_targets.push_back(barT1); - f.strategy->_targets.push_back(barT2); - - SimpleHandler::SP fooH(new SimpleHandler({fooT1, fooT2}, "fooH", 25)); - DocTypeName dtnvfoo("foo"); - f.engine.putFlushHandler(dtnvfoo, fooH); - - SimpleHandler::SP barH(new SimpleHandler({barT1, barT2}, "barH", 20)); - DocTypeName dtnvbar("bar"); - f.engine.putFlushHandler(dtnvbar, barH); + auto fooT1 = std::make_shared<SimpleTarget>("fooT1", 10); + auto fooT2 = std::make_shared<SimpleTarget>("fooT2", 20); + auto barT1 = std::make_shared<SimpleTarget>("barT1", 5); + auto barT2 = std::make_shared<SimpleTarget>("barT2", 15); + f.addTargetToStrategy(fooT1); + f.addTargetToStrategy(fooT2); + f.addTargetToStrategy(barT1); + f.addTargetToStrategy(barT2); + + auto fooH = std::make_shared<SimpleHandler>(Targets({fooT1, fooT2}), "fooH", 25); + f.putFlushHandler("foo", fooH); + + auto barH = std::make_shared<SimpleHandler>(Targets({barT1, barT2}), "barH", 20); + f.putFlushHandler("bar", barH); f.engine.start(); @@ -574,11 +600,10 @@ TEST_F("require that oldest serial is found in group", Fixture(2, IINTERVAL)) TEST_F("require that target can refuse flush", Fixture(2, IINTERVAL)) { - SimpleTarget::SP target(new SimpleTarget()); - SimpleHandler::SP handler(new SimpleHandler({target})); + auto target = std::make_shared<SimpleTarget>(); + auto handler = std::make_shared<SimpleHandler>(Targets({target})); target->_task = searchcorespi::FlushTask::UP(); - DocTypeName dtnvanon("anon"); - f.engine.putFlushHandler(dtnvanon, handler); + f.putFlushHandler("anon", handler); f.engine.start(); EXPECT_TRUE(target->_initDone.await(LONG_TIMEOUT)); @@ -589,10 +614,9 @@ TEST_F("require that target can refuse flush", Fixture(2, IINTERVAL)) TEST_F("require that targets are flushed when nothing new to flush", Fixture(2, IINTERVAL)) { - SimpleTarget::SP target(new SimpleTarget("anon", 5)); // oldest unflushed serial num = 5 - SimpleHandler::SP handler(new SimpleHandler({target}, "anon", 4)); // current serial num = 4 - DocTypeName dtnvanon("anon"); - f.engine.putFlushHandler(dtnvanon, handler); + auto target = std::make_shared<SimpleTarget>("anon", 5); // oldest unflushed serial num = 5 + auto handler = std::make_shared<SimpleHandler>(Targets({target}), "anon", 4); // current serial num = 4 + f.putFlushHandler("anon", handler); f.engine.start(); EXPECT_TRUE(target->_initDone.await(LONG_TIMEOUT)); @@ -602,14 +626,13 @@ TEST_F("require that targets are flushed when nothing new to flush", TEST_F("require that flushing targets are skipped", Fixture(2, IINTERVAL)) { - SimpleTarget::SP foo(new SimpleTarget("foo")); - SimpleTarget::SP bar(new SimpleTarget("bar")); - f.strategy->_targets.push_back(foo); - f.strategy->_targets.push_back(bar); - - SimpleHandler::SP handler(new SimpleHandler({bar, foo})); - DocTypeName dtnvanon("anon"); - f.engine.putFlushHandler(dtnvanon, handler); + auto foo = std::make_shared<SimpleTarget>("foo"); + auto bar = std::make_shared<SimpleTarget>("bar"); + f.addTargetToStrategy(foo); + f.addTargetToStrategy(bar); + + auto handler = std::make_shared<SimpleHandler>(Targets({bar, foo})); + f.putFlushHandler("anon", handler); f.engine.start(); EXPECT_TRUE(foo->_taskDone.await(LONG_TIMEOUT)); @@ -618,12 +641,11 @@ TEST_F("require that flushing targets are skipped", Fixture(2, IINTERVAL)) TEST_F("require that updated targets are not skipped", Fixture(2, IINTERVAL)) { - SimpleTarget::SP target(new SimpleTarget("target", 1)); - f.strategy->_targets.push_back(target); + auto target = std::make_shared<SimpleTarget>("target", 1); + f.addTargetToStrategy(target); - SimpleHandler::SP handler(new SimpleHandler({target}, "handler", 0)); - DocTypeName dtnvhandler("handler"); - f.engine.putFlushHandler(dtnvhandler, handler); + auto handler = std::make_shared<SimpleHandler>(Targets({target}), "handler", 0); + f.putFlushHandler("handler", handler); f.engine.start(); EXPECT_TRUE(target->_taskDone.await(LONG_TIMEOUT)); @@ -633,8 +655,7 @@ TEST("require that threaded target works") { SimpleExecutor executor; SimpleGetSerialNum getSerialNum; - IFlushTarget::SP target(new SimpleTarget()); - target.reset(new ThreadedFlushTarget(executor, getSerialNum, target)); + auto target = std::make_shared<ThreadedFlushTarget>(executor, getSerialNum, std::make_shared<SimpleTarget>()); EXPECT_FALSE(executor._done.await(SHORT_TIMEOUT)); EXPECT_TRUE(target->initFlush(0).get() != NULL); @@ -643,8 +664,7 @@ TEST("require that threaded target works") TEST("require that cached target works") { - IFlushTarget::SP target(new AssertedTarget()); - target.reset(new CachedFlushTarget(target)); + auto target = std::make_shared<CachedFlushTarget>(std::make_shared<AssertedTarget>()); for (uint32_t i = 0; i < 2; ++i) { EXPECT_EQUAL(0l, target->getApproxMemoryGain().getBefore()); EXPECT_EQUAL(0l, target->getApproxMemoryGain().getAfter()); @@ -654,12 +674,11 @@ TEST("require that cached target works") TEST_F("require that trigger flush works", Fixture(2, IINTERVAL)) { - SimpleTarget::SP target(new SimpleTarget("target", 1)); - f.strategy->_targets.push_back(target); + auto target = std::make_shared<SimpleTarget>("target", 1); + f.addTargetToStrategy(target); - SimpleHandler::SP handler(new SimpleHandler({target}, "handler", 9)); - DocTypeName dtnvhandler("handler"); - f.engine.putFlushHandler(dtnvhandler, handler); + auto handler = std::make_shared<SimpleHandler>(Targets({target}), "handler", 9); + f.putFlushHandler("handler", handler); f.engine.start(); f.engine.triggerFlush(); EXPECT_TRUE(target->_initDone.await(LONG_TIMEOUT)); @@ -693,13 +712,13 @@ assertThatHandlersInCurrentSet(FlushEngine & engine, const std::vector<const cha TEST_F("require that concurrency works", Fixture(2, 1)) { - SimpleTarget::SP target1(new SimpleTarget("target1", 1, false)); - SimpleTarget::SP target2(new SimpleTarget("target2", 2, false)); - SimpleTarget::SP target3(new SimpleTarget("target3", 3, false)); - SimpleHandler::SP handler(new SimpleHandler({target1, target2, target3}, "handler", 9)); - DocTypeName dtnvhandler("handler"); - f.engine.putFlushHandler(dtnvhandler, handler); + auto target1 = std::make_shared<SimpleTarget>("target1", 1, false); + auto target2 = std::make_shared<SimpleTarget>("target2", 2, false); + auto target3 = std::make_shared<SimpleTarget>("target3", 3, false); + auto handler = std::make_shared<SimpleHandler>(Targets({target1, target2, target3}), "handler", 9); + f.putFlushHandler("handler", handler); f.engine.start(); + EXPECT_TRUE(target1->_initDone.await(LONG_TIMEOUT)); EXPECT_TRUE(target2->_initDone.await(LONG_TIMEOUT)); EXPECT_TRUE(!target3->_initDone.await(SHORT_TIMEOUT)); @@ -714,11 +733,11 @@ TEST_F("require that concurrency works", Fixture(2, 1)) TEST_F("require that state explorer can list flush targets", Fixture(1, 1)) { - SimpleTarget::SP target = std::make_shared<SimpleTarget>("target1", 100, false); - f.engine.putFlushHandler(DocTypeName("handler"), - std::make_shared<SimpleHandler>( - Targets({target, std::make_shared<SimpleTarget>("target2", 50, true)}), - "handler", 9)); + auto target = std::make_shared<SimpleTarget>("target1", 100, false); + f.putFlushHandler("handler", + std::make_shared<SimpleHandler>( + Targets({target, std::make_shared<SimpleTarget>("target2", 50, true)}), + "handler", 9)); f.engine.start(); target->_initDone.await(LONG_TIMEOUT); target->_taskStart.await(LONG_TIMEOUT); diff --git a/searchcore/src/tests/proton/matching/matching_test.cpp b/searchcore/src/tests/proton/matching/matching_test.cpp index de6a452baf3..7c6779fdc63 100644 --- a/searchcore/src/tests/proton/matching/matching_test.cpp +++ b/searchcore/src/tests/proton/matching/matching_test.cpp @@ -53,6 +53,7 @@ using namespace search; using search::attribute::test::MockAttributeContext; using search::index::schema::DataType; using storage::spi::Timestamp; +using search::fef::indexproperties::hitcollector::HeapSize; void inject_match_phase_limiting(Properties &setup, const vespalib::string &attribute, size_t max_hits, bool descending) { @@ -287,7 +288,7 @@ struct MyWorld { Matcher::SP matcher = createMatcher(); search::fef::Properties overrides; auto mtf = matcher->create_match_tools_factory(*req, searchContext, attributeContext, metaStore, overrides); - auto diversity = mtf->createDiversifier(); + auto diversity = mtf->createDiversifier(HeapSize::lookup(config)); EXPECT_EQUAL(expectDiverse, static_cast<bool>(diversity)); } diff --git a/searchcore/src/vespa/searchcore/fdispatch/common/rpc.cpp b/searchcore/src/vespa/searchcore/fdispatch/common/rpc.cpp index 0d2f6dff983..eaff3b90d78 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/common/rpc.cpp +++ b/searchcore/src/vespa/searchcore/fdispatch/common/rpc.cpp @@ -43,12 +43,12 @@ FastS_RPC::Init(int port, const vespalib::string &myHeartbeatId) void FastS_RPC::RegisterMethods(FRT_ReflectionBuilder *rb) { - rb->DefineMethod("fs.admin.getNodeType", "", "s", true, + rb->DefineMethod("fs.admin.getNodeType", "", "s", FRT_METHOD(FastS_RPC::RPC_GetNodeType), this); rb->MethodDesc("Get string indicating the node type"); rb->ReturnDesc("type", "node type"); //---------------------------------------------------------------// - rb->DefineMethod("fs.admin.getCompileInfo", "", "*", true, + rb->DefineMethod("fs.admin.getCompileInfo", "", "*", FRT_METHOD(FastS_RPC::RPC_GetCompileInfo), this); rb->MethodDesc("Obtain compile info for this node"); rb->ReturnDesc("info", "any number of descriptive strings"); diff --git a/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp b/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp index 388d817e596..b85e706397d 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp +++ b/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp @@ -8,6 +8,7 @@ #include <vespa/searchcore/util/eventloop.h> #include <vespa/vespalib/util/exceptions.h> #include <vespa/config/helper/configgetter.hpp> +#include <vespa/vespalib/net/crypto_engine.h> #include <vespa/log/log.h> LOG_SETUP(".fdispatch"); @@ -296,7 +297,7 @@ Fdispatch::Init() LOG(debug, "Creating FNET transport"); - _transport = std::make_unique<FNET_Transport>(_config->transportthreads); + _transport = std::make_unique<FNET_Transport>(std::make_shared<vespalib::NullCryptoEngine>(), _config->transportthreads); // disable encryption // grab node slowness limit defaults diff --git a/searchcore/src/vespa/searchcore/fdispatch/program/rpc.cpp b/searchcore/src/vespa/searchcore/fdispatch/program/rpc.cpp index 4217ef6d8c9..56301c5e986 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/program/rpc.cpp +++ b/searchcore/src/vespa/searchcore/fdispatch/program/rpc.cpp @@ -9,13 +9,13 @@ FastS_fdispatch_RPC::RegisterMethods(FRT_ReflectionBuilder *rb) { FastS_RPC::RegisterMethods(rb); //------------------------------------------------------------------ - rb->DefineMethod("fs.admin.enableEngine", "s", "i", true, + rb->DefineMethod("fs.admin.enableEngine", "s", "i", FRT_METHOD(FastS_fdispatch_RPC::RPC_EnableEngine), this); rb->MethodDesc("Enable the given engine (clear badness)."); rb->ParamDesc("name", "engine name"); rb->ReturnDesc("count", "number of engines affected"); //------------------------------------------------------------------ - rb->DefineMethod("fs.admin.disableEngine", "s", "i", true, + rb->DefineMethod("fs.admin.disableEngine", "s", "i", FRT_METHOD(FastS_fdispatch_RPC::RPC_DisableEngine), this); rb->MethodDesc("Disable the given engine (mark as admin bad)."); rb->ParamDesc("name", "engine name"); diff --git a/searchcore/src/vespa/searchcore/grouping/groupingcontext.cpp b/searchcore/src/vespa/searchcore/grouping/groupingcontext.cpp index f2215fff978..93153a920cf 100644 --- a/searchcore/src/vespa/searchcore/grouping/groupingcontext.cpp +++ b/searchcore/src/vespa/searchcore/grouping/groupingcontext.cpp @@ -2,6 +2,7 @@ #include "groupingcontext.h" #include <vespa/searchlib/aggregation/predicates.h> +#include <vespa/searchlib/aggregation/modifiers.h> namespace search { @@ -21,6 +22,8 @@ GroupingContext::deserialize(const char *groupSpec, uint32_t groupSpecLen) for (size_t i = 0; i < numGroupings; i++) { GroupingPtr grouping(new search::aggregation::Grouping); grouping->deserialize(nis); + aggregation::Attribute2AttributeKeyed attr2AttrKeyed; + grouping->select(attr2AttrKeyed, attr2AttrKeyed); grouping->setClock(&_clock); grouping->setTimeOfDoom(_timeOfDoom); _groupingList.push_back(grouping); diff --git a/searchcore/src/vespa/searchcore/proton/common/eventlogger.cpp b/searchcore/src/vespa/searchcore/proton/common/eventlogger.cpp index 78f73742fed..5966589d635 100644 --- a/searchcore/src/vespa/searchcore/proton/common/eventlogger.cpp +++ b/searchcore/src/vespa/searchcore/proton/common/eventlogger.cpp @@ -109,13 +109,17 @@ EventLogger::flushStart(const string &name, int64_t beforeMemory, int64_t afterM } void -EventLogger::flushComplete(const string &name, int64_t elapsedTimeMs, +EventLogger::flushComplete(const string &name, int64_t elapsedTimeMs, SerialNum flushed, const string &outputPath, size_t outputPathElems) { JSONStringer jstr; jstr.beginObject(); jstr.appendKey("name").appendString(name); jstr.appendKey("time.elapsed.ms").appendInt64(elapsedTimeMs); + jstr.appendKey("serialnum") + .beginObject() + .appendKey("flushed").appendInt64(flushed) + .endObject(); if (!outputPath.empty()) { jstr.appendKey("output"); LogUtil::logDir(jstr, outputPath, outputPathElems); @@ -124,6 +128,20 @@ EventLogger::flushComplete(const string &name, int64_t elapsedTimeMs, EV_STATE("flush.complete", jstr.toString().data()); } +void +EventLogger::flushPrune(const string &name, SerialNum oldestFlushed) +{ + JSONStringer jstr; + jstr.beginObject(); + jstr.appendKey("name").appendString(name); + jstr.appendKey("serialnum") + .beginObject() + .appendKey("oldestflushed").appendInt64(oldestFlushed) + .endObject(); + jstr.endObject(); + EV_STATE("flush.prune", jstr.toString().data()); +} + namespace { void diff --git a/searchcore/src/vespa/searchcore/proton/common/eventlogger.h b/searchcore/src/vespa/searchcore/proton/common/eventlogger.h index 6ba8852496e..574e650732a 100644 --- a/searchcore/src/vespa/searchcore/proton/common/eventlogger.h +++ b/searchcore/src/vespa/searchcore/proton/common/eventlogger.h @@ -41,8 +41,10 @@ public: SerialNum current); static void flushComplete(const string &name, int64_t elapsedTimeMs, + SerialNum flushed, const string &outputPath, size_t outputPathElems); + static void flushPrune(const string &name, SerialNum oldestFlushed); static void loadAttributeStart(const vespalib::string &subDbName, const vespalib::string &attrName); static void loadAttributeComplete(const vespalib::string &subDbName, const vespalib::string &attrName, int64_t elapsedTimeMs); diff --git a/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.cpp b/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.cpp index 0d2c556b4d6..f7e0b7981bb 100644 --- a/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.cpp +++ b/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.cpp @@ -22,15 +22,23 @@ namespace proton { namespace { -search::SerialNum -findOldestFlushedSerial(const IFlushTarget::List &lst, const IFlushHandler &handler) +std::pair<search::SerialNum, vespalib::string> +findOldestFlushedTarget(const IFlushTarget::List &lst, const IFlushHandler &handler) { - search::SerialNum ret(handler.getCurrentSerialNumber()); - for (const IFlushTarget::SP & target : lst) { - ret = std::min(ret, target->getFlushedSerialNum()); + search::SerialNum oldestFlushedSerial = handler.getCurrentSerialNumber(); + vespalib::string oldestFlushedName = "null"; + for (const IFlushTarget::SP &target : lst) { + if (target->getType() != IFlushTarget::Type::GC) { + search::SerialNum targetFlushedSerial = target->getFlushedSerialNum(); + if (targetFlushedSerial <= oldestFlushedSerial) { + oldestFlushedSerial = targetFlushedSerial; + oldestFlushedName = target->getName(); + } + } } - LOG(debug, "Oldest flushed serial for '%s' is %" PRIu64 ".", handler.getName().c_str(), ret); - return ret; + LOG(debug, "Oldest flushed serial for handler='%s', target='%s': %" PRIu64 ".", + handler.getName().c_str(), oldestFlushedName.c_str(), oldestFlushedSerial); + return std::make_pair(oldestFlushedSerial, oldestFlushedName); } void @@ -174,6 +182,16 @@ FlushEngine::Run(FastOS_ThreadInterface *, void *) prune(); } +namespace { + +vespalib::string +createName(const IFlushHandler &handler, const vespalib::string &targetName) +{ + return (handler.getName() + "." + targetName); +} + +} + bool FlushEngine::prune() { @@ -187,7 +205,11 @@ FlushEngine::prune() } for (const auto &handler : toPrune) { IFlushTarget::List lst = handler->getFlushTargets(); - handler->flushDone(findOldestFlushedSerial(lst, *handler)); + auto oldestFlushed = findOldestFlushedTarget(lst, *handler); + if (LOG_WOULD_LOG(event)) { + EventLogger::flushPrune(createName(*handler, oldestFlushed.second), oldestFlushed.first); + } + handler->flushDone(oldestFlushed.first); } return true; } @@ -333,7 +355,8 @@ FlushEngine::flushDone(const FlushContext &ctx, uint32_t taskId) } if (LOG_WOULD_LOG(event)) { FlushStats stats = ctx.getTarget()->getLastFlushStats(); - EventLogger::flushComplete(ctx.getName(), duration.ms(), stats.getPath(), stats.getPathElementsToLog()); + EventLogger::flushComplete(ctx.getName(), duration.ms(), ctx.getTarget()->getFlushedSerialNum(), + stats.getPath(), stats.getPathElementsToLog()); } LOG(debug, "FlushEngine::flushDone(taskId='%d') took '%f' secs", taskId, duration.sec()); std::lock_guard<std::mutex> guard(_lock); diff --git a/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.cpp b/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.cpp index 52c249fe13b..4f14f709d29 100644 --- a/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.cpp +++ b/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.cpp @@ -10,11 +10,7 @@ InitializerTask::InitializerTask() { } - -InitializerTask::~InitializerTask() -{ -} - +InitializerTask::~InitializerTask() = default; void InitializerTask::addDependency(SP dependency) @@ -22,5 +18,4 @@ InitializerTask::addDependency(SP dependency) _dependencies.emplace_back(std::move(dependency)); } -} // namespace proton::initializer - +} diff --git a/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.h b/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.h index b84db9d6402..ecf98b86fc4 100644 --- a/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.h +++ b/searchcore/src/vespa/searchcore/proton/initializer/initializer_task.h @@ -4,9 +4,7 @@ #include <memory> #include <vector> -namespace proton { - -namespace initializer { +namespace proton::initializer { /* * Class representign an initializer task, used to load a data @@ -35,6 +33,4 @@ public: virtual void run() = 0; }; -} // namespace proton::initializer - -} // namespace proton +} diff --git a/searchcore/src/vespa/searchcore/proton/initializer/task_runner.cpp b/searchcore/src/vespa/searchcore/proton/initializer/task_runner.cpp index 770f00dc264..86c2b525113 100644 --- a/searchcore/src/vespa/searchcore/proton/initializer/task_runner.cpp +++ b/searchcore/src/vespa/searchcore/proton/initializer/task_runner.cpp @@ -92,8 +92,7 @@ TaskRunner::runTask(InitializerTask::SP task) vespalib::ThreadStackExecutor executor(1, 128 * 1024); std::promise<void> promise; auto future = promise.get_future(); - runTask(task, executor, - makeLambdaTask([&]() { promise.set_value(); })); + runTask(task, executor, makeLambdaTask([&]() { promise.set_value(); })); future.wait(); } @@ -119,8 +118,7 @@ TaskRunner::runTask(InitializerTask::SP rootTask, vespalib::Executor &contextExecutor, vespalib::Executor::Task::UP doneTask) { - Context::SP context(std::make_shared<Context>(rootTask, contextExecutor, - std::move(doneTask))); + auto context(std::make_shared<Context>(rootTask, contextExecutor, std::move(doneTask))); context->execute(makeLambdaTask([=]() { pollTask(context); } )); } diff --git a/searchcore/src/vespa/searchcore/proton/initializer/task_runner.h b/searchcore/src/vespa/searchcore/proton/initializer/task_runner.h index 3b52936917c..f28c46334bc 100644 --- a/searchcore/src/vespa/searchcore/proton/initializer/task_runner.h +++ b/searchcore/src/vespa/searchcore/proton/initializer/task_runner.h @@ -6,9 +6,7 @@ #include <vespa/vespalib/stllike/hash_set.h> #include <cassert> -namespace proton { - -namespace initializer { +namespace proton::initializer { /* * Class to run multiple init tasks with dependent tasks. @@ -46,20 +44,15 @@ class TaskRunner { void schedulePoll(); }; void getReadyTasks(const InitializerTask::SP task, TaskList &readyTasks, TaskSet &checked); - void setTaskRunning(InitializerTask &task); - void setTaskDone(InitializerTask &task, Context::SP context); - void internalRunTask(InitializerTask::SP task, Context::SP context); - void internalRunTasks(const TaskList &taskList, Context::SP context); - void pollTask(Context::SP context); public: TaskRunner(vespalib::Executor &executor); - virtual ~TaskRunner(); + ~TaskRunner(); // Depecreated blocking API void runTask(InitializerTask::SP task); @@ -70,6 +63,4 @@ public: vespalib::Executor::Task::UP doneTask); }; -} // namespace proton::initializer - -} // namespace proton +} diff --git a/searchcore/src/vespa/searchcore/proton/matching/match_master.cpp b/searchcore/src/vespa/searchcore/proton/matching/match_master.cpp index d974be1ce3a..4d49e9b5d1b 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/match_master.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/match_master.cpp @@ -64,7 +64,7 @@ MatchMaster::match(const MatchParams ¶ms, fastos::StopWatch query_latency_time; query_latency_time.start(); vespalib::DualMergeDirector mergeDirector(threadBundle.size()); - MatchLoopCommunicator communicator(threadBundle.size(), params.heapSize, mtf.createDiversifier()); + MatchLoopCommunicator communicator(threadBundle.size(), params.heapSize, mtf.createDiversifier(params.heapSize)); TimedMatchLoopCommunicator timedCommunicator(communicator); DocidRangeScheduler::UP scheduler = createScheduler(threadBundle.size(), numSearchPartitions, params.numDocs); diff --git a/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp b/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp index a00a90d7a10..28d56b7e0a2 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp @@ -204,7 +204,7 @@ MatchToolsFactory::createMatchTools() const } std::unique_ptr<IDiversifier> -MatchToolsFactory::createDiversifier() const +MatchToolsFactory::createDiversifier(uint32_t heapSize) const { if ( !_diversityParams.enabled() ) { return std::unique_ptr<IDiversifier>(); @@ -214,8 +214,8 @@ MatchToolsFactory::createDiversifier() const LOG(warning, "Skipping diversity due to no %s attribute.", _diversityParams.attribute.c_str()); return std::unique_ptr<IDiversifier>(); } - size_t max_per_group = _rankSetup.getHeapSize()/_diversityParams.min_groups; - return DiversityFilter::create(*attr, _rankSetup.getHeapSize(), max_per_group, _diversityParams.min_groups, + size_t max_per_group = heapSize/_diversityParams.min_groups; + return DiversityFilter::create(*attr, heapSize, max_per_group, _diversityParams.min_groups, _diversityParams.cutoff_strategy == DiversityParams::CutoffStrategy::STRICT); } diff --git a/searchcore/src/vespa/searchcore/proton/matching/match_tools.h b/searchcore/src/vespa/searchcore/proton/matching/match_tools.h index 8f04eebc50e..0ecf6eb5b78 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/match_tools.h +++ b/searchcore/src/vespa/searchcore/proton/matching/match_tools.h @@ -124,7 +124,7 @@ public: const MaybeMatchPhaseLimiter &match_limiter() const { return *_match_limiter; } MatchTools::UP createMatchTools() const; bool should_diversify() const { return _diversityParams.enabled(); } - std::unique_ptr<search::queryeval::IDiversifier> createDiversifier() const; + std::unique_ptr<search::queryeval::IDiversifier> createDiversifier(uint32_t heapSize) const; search::queryeval::Blueprint::HitEstimate estimate() const { return _query.estimate(); } bool has_first_phase_rank() const { return !_rankSetup.getFirstPhaseRank().empty(); } std::unique_ptr<AttributeOperationTask> createOnMatchTask() const; diff --git a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp index be0a720f1c1..b32af7e3e5a 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp @@ -28,6 +28,7 @@ using search::FeatureSet; using search::attribute::IAttributeContext; using search::fef::MatchDataLayout; using search::fef::MatchData; +using search::fef::indexproperties::hitcollector::HeapSize; using search::queryeval::Blueprint; using search::queryeval::SearchIterator; using vespalib::Doom; @@ -242,14 +243,16 @@ Matcher::match(const SearchRequest &request, vespalib::ThreadBundle &threadBundl return reply; } - MatchParams params(searchContext.getDocIdLimit(), _rankSetup->getHeapSize(), _rankSetup->getArraySize(), + const Properties & rankProperties = request.propertiesMap.rankProperties(); + uint32_t heapSize = HeapSize::lookup(rankProperties, _rankSetup->getHeapSize()); + + MatchParams params(searchContext.getDocIdLimit(), heapSize, _rankSetup->getArraySize(), _rankSetup->getRankScoreDropLimit(), request.offset, request.maxhits, !_rankSetup->getSecondPhaseRank().empty(), !willNotNeedRanking(request, groupingContext)); ResultProcessor rp(attrContext, metaStore, sessionMgr, groupingContext, sessionId, request.sortSpec, params.offset, params.hits, request.should_drop_sort_data()); - const Properties & rankProperties = request.propertiesMap.rankProperties(); size_t numThreadsPerSearch = computeNumThreadsPerSearch(mtf->estimate(), rankProperties); LimitedThreadBundleWrapper limitedThreadBundle(threadBundle, numThreadsPerSearch); MatchMaster master; diff --git a/searchcore/src/vespa/searchcore/proton/server/ddbstate.cpp b/searchcore/src/vespa/searchcore/proton/server/ddbstate.cpp index 09b81f373df..73fff1cfd42 100644 --- a/searchcore/src/vespa/searchcore/proton/server/ddbstate.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/ddbstate.cpp @@ -33,10 +33,7 @@ DDBState::DDBState() } -DDBState::~DDBState() -{ - -} +DDBState::~DDBState() = default; bool diff --git a/searchcore/src/vespa/searchcore/proton/server/rpc_hooks.cpp b/searchcore/src/vespa/searchcore/proton/server/rpc_hooks.cpp index ab012760762..6e442f472b1 100644 --- a/searchcore/src/vespa/searchcore/proton/server/rpc_hooks.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/rpc_hooks.cpp @@ -127,7 +127,7 @@ RPCHooksBase::initRPC() FRT_ReflectionBuilder rb(_orb.get()); //------------------------------------------------------------------------- - rb.DefineMethod("pandora.rtc.getState", "ii", "SSi", true, + rb.DefineMethod("pandora.rtc.getState", "ii", "SSi", FRT_METHOD(RPCHooksBase::rpc_GetState), this); rb.MethodDesc("Get the current state of node"); rb.ParamDesc("gencnt", "old state generation held by the client"); @@ -136,7 +136,7 @@ RPCHooksBase::initRPC() rb.ReturnDesc("values", "Array of state values"); rb.ReturnDesc("newgen", "New state generation count"); //------------------------------------------------------------------------- - rb.DefineMethod("proton.getStatus", "s", "SSSS", true, + rb.DefineMethod("proton.getStatus", "s", "SSSS", FRT_METHOD(RPCHooksBase::rpc_GetProtonStatus), this); rb.MethodDesc("Get the current state of proton or a proton component"); rb.ParamDesc("component", "Which component to check the status for"); @@ -145,7 +145,7 @@ RPCHooksBase::initRPC() rb.ReturnDesc("internalStates", "Array of internal states "); rb.ReturnDesc("message", "Array of status messages"); //------------------------------------------------------------------------- - rb.DefineMethod("pandora.rtc.getIncrementalState", "i", "SSi", true, + rb.DefineMethod("pandora.rtc.getIncrementalState", "i", "SSi", FRT_METHOD(RPCHooksBase::rpc_getIncrementalState), this); rb.MethodDesc("Get node state changes since last invocation"); rb.ParamDesc("timeout", "How many milliseconds to wait for state update"); @@ -153,26 +153,26 @@ RPCHooksBase::initRPC() rb.ReturnDesc("values", "Array of state values"); rb.ReturnDesc("dummy", "Dummy value to enable code reuse"); //------------------------------------------------------------------------- - rb.DefineMethod("pandora.rtc.shutdown", "", "", true, + rb.DefineMethod("pandora.rtc.shutdown", "", "", FRT_METHOD(RPCHooksBase::rpc_Shutdown), this); rb.MethodDesc("Shut down the rtc application"); //------------------------------------------------------------------------- - rb.DefineMethod("pandora.rtc.die", "", "", true, + rb.DefineMethod("pandora.rtc.die", "", "", FRT_METHOD(RPCHooksBase::rpc_die), this); rb.MethodDesc("Exit the rtc application without cleanup"); //------------------------------------------------------------------------- - rb.DefineMethod("proton.triggerFlush", "", "b", true, + rb.DefineMethod("proton.triggerFlush", "", "b", FRT_METHOD(RPCHooksBase::rpc_triggerFlush), this); rb.MethodDesc("Tell the node to trigger flush ASAP"); rb.ReturnDesc("success", "Whether or not a flush was triggered."); //------------------------------------------------------------------------- - rb.DefineMethod("proton.prepareRestart", "", "b", true, + rb.DefineMethod("proton.prepareRestart", "", "b", FRT_METHOD(RPCHooksBase::rpc_prepareRestart), this); rb.MethodDesc("Tell the node to prepare for a restart by flushing components " "such that TLS replay time + time spent flushing components is as low as possible"); rb.ReturnDesc("success", "Whether or not prepare for restart was triggered."); //------------------------------------------------------------------------- - rb.DefineMethod("proton.getDocsums", "bix", "bix", true, FRT_METHOD(RPCHooksBase::rpc_getDocSums), this); + rb.DefineMethod("proton.getDocsums", "bix", "bix", FRT_METHOD(RPCHooksBase::rpc_getDocSums), this); rb.MethodDesc("Get list of document summaries"); rb.ParamDesc("encoding", "0=raw, 6=lz4"); rb.ParamDesc("uncompressedBlobSize", "Uncompressed blob size"); diff --git a/searchcore/src/vespa/searchcore/proton/server/transactionlogmanager.cpp b/searchcore/src/vespa/searchcore/proton/server/transactionlogmanager.cpp index 4bece3e6860..72fcf812ebc 100644 --- a/searchcore/src/vespa/searchcore/proton/server/transactionlogmanager.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/transactionlogmanager.cpp @@ -31,8 +31,7 @@ TransactionLogManager::TransactionLogManager(const vespalib::string &tlsSpec, { } -TransactionLogManager::~TransactionLogManager() { -} +TransactionLogManager::~TransactionLogManager() = default; void TransactionLogManager::init(SerialNum oldestConfigSerial, diff --git a/searchlib/CMakeLists.txt b/searchlib/CMakeLists.txt index 1ad8f562384..c55aadb5eae 100644 --- a/searchlib/CMakeLists.txt +++ b/searchlib/CMakeLists.txt @@ -128,6 +128,7 @@ vespa_define_module( src/tests/engine/monitorapi src/tests/engine/searchapi src/tests/engine/transportserver + src/tests/expression/attributenode src/tests/features src/tests/features/beta src/tests/features/constant diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index f7fe91cb56f..ac5eefcc5b2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -23,6 +23,7 @@ public class ImportedModel { private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); private final String name; + private final String source; private final Map<String, Signature> signatures = new HashMap<>(); private final Map<String, TensorType> arguments = new HashMap<>(); @@ -36,16 +37,21 @@ public class ImportedModel { * Creates a new imported model. * * @param name the name of this mode, containing only characters in [A-Za-z0-9_] + * @param source the source path (directory or file) of this model */ - public ImportedModel(String name) { + public ImportedModel(String name, String source) { if ( ! nameRegexp.matcher(name).matches()) throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + name + "'"); this.name = name; + this.source = source; } /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ public String name() { return name; } + /** Returns the source path (directiry or file) of this model */ + public String source() { return source; } + /** Returns an immutable map of the arguments ("Placeholders") of this */ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java index 92cb8c3f360..40d1ca8030a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java @@ -6,7 +6,10 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.path.Path; import java.io.File; +import java.util.Arrays; import java.util.Collection; +import java.util.HashMap; +import java.util.Map; import java.util.Optional; /** @@ -30,25 +33,30 @@ public class ImportedModels { } public ImportedModels(File modelsDirectory) { - ImmutableMap.Builder<String, ImportedModel> builder = new ImmutableMap.Builder<>(); + Map<String, ImportedModel> models = new HashMap<>(); // Find all subdirectories recursively which contains a model we can read - importRecursively(modelsDirectory, builder); - importedModels = builder.build(); + importRecursively(modelsDirectory, models); + importedModels = ImmutableMap.copyOf(models); } - private static void importRecursively(File dir, ImmutableMap.Builder<String, ImportedModel> builder) { + private static void importRecursively(File dir, Map<String, ImportedModel> models) { if ( ! dir.isDirectory()) return; - for (File child : dir.listFiles()) { + + Arrays.stream(dir.listFiles()).sorted().forEach(child -> { Optional<ModelImporter> importer = findImporterOf(child); if (importer.isPresent()) { String name = toName(child); - builder.put(name, importer.get().importModel(name, child)); + ImportedModel existing = models.get(name); + if (existing != null) + throw new IllegalArgumentException("The models in " + child + " and " + existing.source() + + " both resolve to the model name '" + name + "'"); + models.put(name, importer.get().importModel(name, child)); } else { - importRecursively(child, builder); + importRecursively(child, models); } - } + }); } private static Optional<ModelImporter> findImporterOf(File path) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java index 13718935cef..9833e52cb61 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java @@ -46,8 +46,8 @@ public abstract class ModelImporter { * Takes an IntermediateGraph and converts it to a ImportedModel containing * the actual Vespa ranking expressions. */ - static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph) { - ImportedModel model = new ImportedModel(graph.name()); + static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { + ImportedModel model = new ImportedModel(graph.name(), modelSource); graph.optimize(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java index 187e2f2e29d..917b0d6a389 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java @@ -31,7 +31,7 @@ public class OnnxImporter extends ModelImporter { try (FileInputStream inputStream = new FileInputStream(modelPath)) { Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); IntermediateGraph graph = GraphImporter.importGraph(modelName, model); - return convertIntermediateGraphToModel(graph); + return convertIntermediateGraphToModel(graph, modelPath); } catch (IOException e) { throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java index afd01b3d7da..7c18e04bae7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java @@ -39,7 +39,7 @@ public class TensorFlowImporter extends ModelImporter { @Override public ImportedModel importModel(String modelName, String modelDir) { try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { - return importModel(modelName, model); + return importModel(modelName, modelDir, model); } catch (IllegalArgumentException e) { throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); @@ -47,10 +47,10 @@ public class TensorFlowImporter extends ModelImporter { } /** Imports a TensorFlow model */ - ImportedModel importModel(String modelName, SavedModelBundle model) { + ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) { try { IntermediateGraph graph = GraphImporter.importGraph(modelName, model); - return convertIntermediateGraphToModel(graph); + return convertIntermediateGraphToModel(graph, modelDir); } catch (IOException e) { throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java index e08214579db..725f319a839 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java @@ -27,7 +27,7 @@ public class XGBoostImporter extends ModelImporter { @Override public ImportedModel importModel(String modelName, String modelPath) { try { - ImportedModel model = new ImportedModel(modelName); + ImportedModel model = new ImportedModel(modelName, modelPath); XGBoostParser parser = new XGBoostParser(modelPath); RankingExpression expression = new RankingExpression(parser.toRankingExpression()); model.expression(modelName, expression); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java index 39a8b211d09..eee92862e7f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java @@ -10,8 +10,8 @@ import java.util.Map; import java.util.Set; /** - * Holds an intermediate representation of an imported ONNX or TensorFlow - * graph. After this intermediate representation is constructed, it is used to + * Holds an intermediate representation of an imported model graph. + * After this intermediate representation is constructed, it is used to * simplify and optimize the computational graph and then converted into the * final ImportedModel that holds the Vespa ranking expressions for the model. * diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt b/searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt new file mode 100644 index 00000000000..eb926836576 --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt @@ -0,0 +1,7982 @@ +saved_model_schema_version: 1 +meta_graphs { + meta_info_def { + stripped_op_list { + op { + name: "Add" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_STRING + } + } + } + } + op { + name: "AddN" + input_arg { + name: "inputs" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "sum" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + type: DT_VARIANT + } + } + } + is_aggregate: true + is_commutative: true + } + op { + name: "ApplyGradientDescent" + input_arg { + name: "var" + type_attr: "T" + is_ref: true + } + input_arg { + name: "alpha" + type_attr: "T" + } + input_arg { + name: "delta" + type_attr: "T" + } + output_arg { + name: "out" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: false + } + } + } + op { + name: "Assign" + input_arg { + name: "ref" + type_attr: "T" + is_ref: true + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output_ref" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + } + attr { + name: "validate_shape" + type: "bool" + default_value { + b: true + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: true + } + } + allows_uninitialized_input: true + } + op { + name: "BroadcastGradientArgs" + input_arg { + name: "s0" + type_attr: "T" + } + input_arg { + name: "s1" + type_attr: "T" + } + output_arg { + name: "r0" + type_attr: "T" + } + output_arg { + name: "r1" + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Cast" + input_arg { + name: "x" + type_attr: "SrcT" + } + output_arg { + name: "y" + type_attr: "DstT" + } + attr { + name: "SrcT" + type: "type" + } + attr { + name: "DstT" + type: "type" + } + } + op { + name: "Const" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "value" + type: "tensor" + } + attr { + name: "dtype" + type: "type" + } + } + op { + name: "ExpandDims" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "dim" + type_attr: "Tdim" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tdim" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Fill" + input_arg { + name: "dims" + type: DT_INT32 + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "FloorDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "GreaterEqual" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type: DT_BOOL + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_UINT16 + type: DT_HALF + } + } + } + } + op { + name: "Identity" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "InTopKV2" + input_arg { + name: "predictions" + type: DT_FLOAT + } + input_arg { + name: "targets" + type_attr: "T" + } + input_arg { + name: "k" + type_attr: "T" + } + output_arg { + name: "precision" + type: DT_BOOL + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "MatMul" + input_arg { + name: "a" + type_attr: "T" + } + input_arg { + name: "b" + type_attr: "T" + } + output_arg { + name: "product" + type_attr: "T" + } + attr { + name: "transpose_a" + type: "bool" + default_value { + b: false + } + } + attr { + name: "transpose_b" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Maximum" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + is_commutative: true + } + op { + name: "Mean" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "MergeV2Checkpoints" + input_arg { + name: "checkpoint_prefixes" + type: DT_STRING + } + input_arg { + name: "destination_prefix" + type: DT_STRING + } + attr { + name: "delete_old_dirs" + type: "bool" + default_value { + b: true + } + } + is_stateful: true + } + op { + name: "Mul" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "NoOp" + } + op { + name: "Pack" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + } + attr { + name: "axis" + type: "int" + default_value { + i: 0 + } + } + } + op { + name: "Placeholder" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + } + op { + name: "PreventGradient" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "message" + type: "string" + default_value { + s: "" + } + } + } + op { + name: "Prod" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RealDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Reshape" + input_arg { + name: "tensor" + type_attr: "T" + } + input_arg { + name: "shape" + type_attr: "Tshape" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tshape" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RestoreV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + output_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "SaveV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + input_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "ScalarSummary" + input_arg { + name: "tags" + type: DT_STRING + } + input_arg { + name: "values" + type_attr: "T" + } + output_arg { + name: "summary" + type: DT_STRING + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_UINT16 + type: DT_HALF + } + } + } + } + op { + name: "Select" + input_arg { + name: "condition" + type: DT_BOOL + } + input_arg { + name: "t" + type_attr: "T" + } + input_arg { + name: "e" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "Selu" + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "activations" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + } + op { + name: "SeluGrad" + input_arg { + name: "gradients" + type_attr: "T" + } + input_arg { + name: "outputs" + type_attr: "T" + } + output_arg { + name: "backprops" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + } + op { + name: "Shape" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "T" + type: "type" + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "ShardedFilename" + input_arg { + name: "basename" + type: DT_STRING + } + input_arg { + name: "shard" + type: DT_INT32 + } + input_arg { + name: "num_shards" + type: DT_INT32 + } + output_arg { + name: "filename" + type: DT_STRING + } + } + op { + name: "SparseSoftmaxCrossEntropyWithLogits" + input_arg { + name: "features" + type_attr: "T" + } + input_arg { + name: "labels" + type_attr: "Tlabels" + } + output_arg { + name: "loss" + type_attr: "T" + } + output_arg { + name: "backprop" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "Tlabels" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "StringJoin" + input_arg { + name: "inputs" + type: DT_STRING + number_attr: "N" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "separator" + type: "string" + default_value { + s: "" + } + } + } + op { + name: "Sum" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Tile" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "multiples" + type_attr: "Tmultiples" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tmultiples" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "TruncatedNormal" + input_arg { + name: "shape" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "seed" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "seed2" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true + } + op { + name: "VariableV2" + output_arg { + name: "ref" + type_attr: "dtype" + is_ref: true + } + attr { + name: "shape" + type: "shape" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true + } + op { + name: "ZerosLike" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + } + tags: "serve" + tensorflow_version: "1.4.1" + tensorflow_git_version: "v1.4.0-19-ga52c8d9" + } + graph_def { + node { + name: "input" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + node { + name: "y" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "shape" + value { + shape { + unknown_rank: true + } + } + } + } + node { + name: "dnn/hidden1/truncated_normal/shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\020\003\000\000,\001\000\000" + } + } + } + } + node { + name: "dnn/hidden1/truncated_normal/mean" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } + } + node { + name: "dnn/hidden1/truncated_normal/stddev" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0714285746216774 + } + } + } + } + node { + name: "dnn/hidden1/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "dnn/hidden1/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + } + node { + name: "dnn/hidden1/truncated_normal/mul" + op: "Mul" + input: "dnn/hidden1/truncated_normal/TruncatedNormal" + input: "dnn/hidden1/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "dnn/hidden1/truncated_normal" + op: "Add" + input: "dnn/hidden1/truncated_normal/mul" + input: "dnn/hidden1/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "dnn/hidden1/weights" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "dnn/hidden1/weights/Assign" + op: "Assign" + input: "dnn/hidden1/weights" + input: "dnn/hidden1/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "dnn/hidden1/weights/read" + op: "Identity" + input: "dnn/hidden1/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "dnn/hidden1/zeros" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 300 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "dnn/hidden1/bias" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 300 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "dnn/hidden1/bias/Assign" + op: "Assign" + input: "dnn/hidden1/bias" + input: "dnn/hidden1/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "dnn/hidden1/bias/read" + op: "Identity" + input: "dnn/hidden1/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + } + node { + name: "dnn/hidden1/MatMul" + op: "MatMul" + input: "input" + input: "dnn/hidden1/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "dnn/hidden1/add" + op: "Add" + input: "dnn/hidden1/MatMul" + input: "dnn/hidden1/bias/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "dnn/hidden1/mul/x" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.009999999776482582 + } + } + } + } + node { + name: "dnn/hidden1/mul" + op: "Mul" + input: "dnn/hidden1/mul/x" + input: "dnn/hidden1/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "dnn/hidden1/Maximum" + op: "Maximum" + input: "dnn/hidden1/mul" + input: "dnn/hidden1/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "dnn/hidden2/truncated_normal/shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: ",\001\000\000d\000\000\000" + } + } + } + } + node { + name: "dnn/hidden2/truncated_normal/mean" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } + } + node { + name: "dnn/hidden2/truncated_normal/stddev" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.1154700517654419 + } + } + } + } + node { + name: "dnn/hidden2/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "dnn/hidden2/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + } + node { + name: "dnn/hidden2/truncated_normal/mul" + op: "Mul" + input: "dnn/hidden2/truncated_normal/TruncatedNormal" + input: "dnn/hidden2/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "dnn/hidden2/truncated_normal" + op: "Add" + input: "dnn/hidden2/truncated_normal/mul" + input: "dnn/hidden2/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "dnn/hidden2/weights" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "dnn/hidden2/weights/Assign" + op: "Assign" + input: "dnn/hidden2/weights" + input: "dnn/hidden2/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "dnn/hidden2/weights/read" + op: "Identity" + input: "dnn/hidden2/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "dnn/hidden2/zeros" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 100 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "dnn/hidden2/bias" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "dnn/hidden2/bias/Assign" + op: "Assign" + input: "dnn/hidden2/bias" + input: "dnn/hidden2/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "dnn/hidden2/bias/read" + op: "Identity" + input: "dnn/hidden2/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + } + node { + name: "dnn/hidden2/MatMul" + op: "MatMul" + input: "dnn/hidden1/Maximum" + input: "dnn/hidden2/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "dnn/hidden2/add" + op: "Add" + input: "dnn/hidden2/MatMul" + input: "dnn/hidden2/bias/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "dnn/hidden2/Selu" + op: "Selu" + input: "dnn/hidden2/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "dnn/outputs/truncated_normal/shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "d\000\000\000\n\000\000\000" + } + } + } + } + node { + name: "dnn/outputs/truncated_normal/mean" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } + } + node { + name: "dnn/outputs/truncated_normal/stddev" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.20000000298023224 + } + } + } + } + node { + name: "dnn/outputs/truncated_normal/TruncatedNormal" + op: "TruncatedNormal" + input: "dnn/outputs/truncated_normal/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + } + node { + name: "dnn/outputs/truncated_normal/mul" + op: "Mul" + input: "dnn/outputs/truncated_normal/TruncatedNormal" + input: "dnn/outputs/truncated_normal/stddev" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "dnn/outputs/truncated_normal" + op: "Add" + input: "dnn/outputs/truncated_normal/mul" + input: "dnn/outputs/truncated_normal/mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "dnn/outputs/weights" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "dnn/outputs/weights/Assign" + op: "Assign" + input: "dnn/outputs/weights" + input: "dnn/outputs/truncated_normal" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "dnn/outputs/weights/read" + op: "Identity" + input: "dnn/outputs/weights" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "dnn/outputs/zeros" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 10 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "dnn/outputs/bias" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "dnn/outputs/bias/Assign" + op: "Assign" + input: "dnn/outputs/bias" + input: "dnn/outputs/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "dnn/outputs/bias/read" + op: "Identity" + input: "dnn/outputs/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "dnn/outputs/MatMul" + op: "MatMul" + input: "dnn/hidden2/Selu" + input: "dnn/outputs/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "dnn/outputs/add" + op: "Add" + input: "dnn/outputs/MatMul" + input: "dnn/outputs/bias/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "loss/SparseSoftmaxCrossEntropyWithLogits/Shape" + op: "Shape" + input: "y" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" + op: "SparseSoftmaxCrossEntropyWithLogits" + input: "dnn/outputs/add" + input: "y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tlabels" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "loss/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "loss/loss" + op: "Mean" + input: "loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" + input: "loss/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/Shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "train/gradients/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } + } + node { + name: "train/gradients/Fill" + op: "Fill" + input: "train/gradients/Shape" + input: "train/gradients/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Reshape/shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Reshape" + op: "Reshape" + input: "train/gradients/Fill" + input: "train/gradients/loss/loss_grad/Reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Shape" + op: "Shape" + input: "loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/loss/loss_grad/Tile" + op: "Tile" + input: "train/gradients/loss/loss_grad/Reshape" + input: "train/gradients/loss/loss_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tmultiples" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Shape_1" + op: "Shape" + input: "loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/loss/loss_grad/Shape_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/loss/loss_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Prod" + op: "Prod" + input: "train/gradients/loss/loss_grad/Shape_1" + input: "train/gradients/loss/loss_grad/Const" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/loss/loss_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/loss/loss_grad/Const_1" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/loss/loss_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Prod_1" + op: "Prod" + input: "train/gradients/loss/loss_grad/Shape_2" + input: "train/gradients/loss/loss_grad/Const_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/loss/loss_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/loss/loss_grad/Maximum/y" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/loss/loss_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Maximum" + op: "Maximum" + input: "train/gradients/loss/loss_grad/Prod_1" + input: "train/gradients/loss/loss_grad/Maximum/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/loss/loss_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/floordiv" + op: "FloorDiv" + input: "train/gradients/loss/loss_grad/Prod" + input: "train/gradients/loss/loss_grad/Maximum" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/loss/loss_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/Cast" + op: "Cast" + input: "train/gradients/loss/loss_grad/floordiv" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "train/gradients/loss/loss_grad/truediv" + op: "RealDiv" + input: "train/gradients/loss/loss_grad/Tile" + input: "train/gradients/loss/loss_grad/Cast" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "train/gradients/zeros_like" + op: "ZerosLike" + input: "loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/PreventGradient" + op: "PreventGradient" + input: "loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "message" + value { + s: "Currently there is no way to take the second derivative of sparse_softmax_cross_entropy_with_logits due to the fused implementation\'s interaction with tf.gradients()" + } + } + } + node { + name: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: -1 + } + } + } + } + node { + name: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/ExpandDims" + op: "ExpandDims" + input: "train/gradients/loss/loss_grad/truediv" + input: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/mul" + op: "Mul" + input: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/ExpandDims" + input: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/PreventGradient" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/Shape" + op: "Shape" + input: "dnn/outputs/MatMul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/Shape_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 10 + } + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "train/gradients/dnn/outputs/add_grad/Shape" + input: "train/gradients/dnn/outputs/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/Sum" + op: "Sum" + input: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/mul" + input: "train/gradients/dnn/outputs/add_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/Reshape" + op: "Reshape" + input: "train/gradients/dnn/outputs/add_grad/Sum" + input: "train/gradients/dnn/outputs/add_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/Sum_1" + op: "Sum" + input: "train/gradients/loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_grad/mul" + input: "train/gradients/dnn/outputs/add_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/Reshape_1" + op: "Reshape" + input: "train/gradients/dnn/outputs/add_grad/Sum_1" + input: "train/gradients/dnn/outputs/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/outputs/add_grad/Reshape" + input: "^train/gradients/dnn/outputs/add_grad/Reshape_1" + } + node { + name: "train/gradients/dnn/outputs/add_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/outputs/add_grad/Reshape" + input: "^train/gradients/dnn/outputs/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/outputs/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/outputs/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/outputs/add_grad/Reshape_1" + input: "^train/gradients/dnn/outputs/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/outputs/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/outputs/MatMul_grad/MatMul" + op: "MatMul" + input: "train/gradients/dnn/outputs/add_grad/tuple/control_dependency" + input: "dnn/outputs/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } + } + node { + name: "train/gradients/dnn/outputs/MatMul_grad/MatMul_1" + op: "MatMul" + input: "dnn/hidden2/Selu" + input: "train/gradients/dnn/outputs/add_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/outputs/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/outputs/MatMul_grad/MatMul" + input: "^train/gradients/dnn/outputs/MatMul_grad/MatMul_1" + } + node { + name: "train/gradients/dnn/outputs/MatMul_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/outputs/MatMul_grad/MatMul" + input: "^train/gradients/dnn/outputs/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/outputs/MatMul_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/outputs/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/outputs/MatMul_grad/MatMul_1" + input: "^train/gradients/dnn/outputs/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/outputs/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/Selu_grad/SeluGrad" + op: "SeluGrad" + input: "train/gradients/dnn/outputs/MatMul_grad/tuple/control_dependency" + input: "dnn/hidden2/Selu" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/Shape" + op: "Shape" + input: "dnn/hidden2/MatMul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/Shape_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 100 + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "train/gradients/dnn/hidden2/add_grad/Shape" + input: "train/gradients/dnn/hidden2/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/Sum" + op: "Sum" + input: "train/gradients/dnn/hidden2/Selu_grad/SeluGrad" + input: "train/gradients/dnn/hidden2/add_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/Reshape" + op: "Reshape" + input: "train/gradients/dnn/hidden2/add_grad/Sum" + input: "train/gradients/dnn/hidden2/add_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/Sum_1" + op: "Sum" + input: "train/gradients/dnn/hidden2/Selu_grad/SeluGrad" + input: "train/gradients/dnn/hidden2/add_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/Reshape_1" + op: "Reshape" + input: "train/gradients/dnn/hidden2/add_grad/Sum_1" + input: "train/gradients/dnn/hidden2/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/hidden2/add_grad/Reshape" + input: "^train/gradients/dnn/hidden2/add_grad/Reshape_1" + } + node { + name: "train/gradients/dnn/hidden2/add_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/hidden2/add_grad/Reshape" + input: "^train/gradients/dnn/hidden2/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden2/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/hidden2/add_grad/Reshape_1" + input: "^train/gradients/dnn/hidden2/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden2/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/MatMul_grad/MatMul" + op: "MatMul" + input: "train/gradients/dnn/hidden2/add_grad/tuple/control_dependency" + input: "dnn/hidden2/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } + } + node { + name: "train/gradients/dnn/hidden2/MatMul_grad/MatMul_1" + op: "MatMul" + input: "dnn/hidden1/Maximum" + input: "train/gradients/dnn/hidden2/add_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden2/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/hidden2/MatMul_grad/MatMul" + input: "^train/gradients/dnn/hidden2/MatMul_grad/MatMul_1" + } + node { + name: "train/gradients/dnn/hidden2/MatMul_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/hidden2/MatMul_grad/MatMul" + input: "^train/gradients/dnn/hidden2/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden2/MatMul_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden2/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/hidden2/MatMul_grad/MatMul_1" + input: "^train/gradients/dnn/hidden2/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden2/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Shape" + op: "Shape" + input: "dnn/hidden1/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Shape_1" + op: "Shape" + input: "dnn/hidden1/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Shape_2" + op: "Shape" + input: "train/gradients/dnn/hidden2/MatMul_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/zeros/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/zeros" + op: "Fill" + input: "train/gradients/dnn/hidden1/Maximum_grad/Shape_2" + input: "train/gradients/dnn/hidden1/Maximum_grad/zeros/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/GreaterEqual" + op: "GreaterEqual" + input: "dnn/hidden1/mul" + input: "dnn/hidden1/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "train/gradients/dnn/hidden1/Maximum_grad/Shape" + input: "train/gradients/dnn/hidden1/Maximum_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Select" + op: "Select" + input: "train/gradients/dnn/hidden1/Maximum_grad/GreaterEqual" + input: "train/gradients/dnn/hidden2/MatMul_grad/tuple/control_dependency" + input: "train/gradients/dnn/hidden1/Maximum_grad/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Select_1" + op: "Select" + input: "train/gradients/dnn/hidden1/Maximum_grad/GreaterEqual" + input: "train/gradients/dnn/hidden1/Maximum_grad/zeros" + input: "train/gradients/dnn/hidden2/MatMul_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Sum" + op: "Sum" + input: "train/gradients/dnn/hidden1/Maximum_grad/Select" + input: "train/gradients/dnn/hidden1/Maximum_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Reshape" + op: "Reshape" + input: "train/gradients/dnn/hidden1/Maximum_grad/Sum" + input: "train/gradients/dnn/hidden1/Maximum_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Sum_1" + op: "Sum" + input: "train/gradients/dnn/hidden1/Maximum_grad/Select_1" + input: "train/gradients/dnn/hidden1/Maximum_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/Reshape_1" + op: "Reshape" + input: "train/gradients/dnn/hidden1/Maximum_grad/Sum_1" + input: "train/gradients/dnn/hidden1/Maximum_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/hidden1/Maximum_grad/Reshape" + input: "^train/gradients/dnn/hidden1/Maximum_grad/Reshape_1" + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/hidden1/Maximum_grad/Reshape" + input: "^train/gradients/dnn/hidden1/Maximum_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/Maximum_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/Maximum_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/hidden1/Maximum_grad/Reshape_1" + input: "^train/gradients/dnn/hidden1/Maximum_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/Maximum_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/Shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/Shape_1" + op: "Shape" + input: "dnn/hidden1/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "train/gradients/dnn/hidden1/mul_grad/Shape" + input: "train/gradients/dnn/hidden1/mul_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/mul" + op: "Mul" + input: "train/gradients/dnn/hidden1/Maximum_grad/tuple/control_dependency" + input: "dnn/hidden1/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/Sum" + op: "Sum" + input: "train/gradients/dnn/hidden1/mul_grad/mul" + input: "train/gradients/dnn/hidden1/mul_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/Reshape" + op: "Reshape" + input: "train/gradients/dnn/hidden1/mul_grad/Sum" + input: "train/gradients/dnn/hidden1/mul_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/mul_1" + op: "Mul" + input: "dnn/hidden1/mul/x" + input: "train/gradients/dnn/hidden1/Maximum_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/Sum_1" + op: "Sum" + input: "train/gradients/dnn/hidden1/mul_grad/mul_1" + input: "train/gradients/dnn/hidden1/mul_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/Reshape_1" + op: "Reshape" + input: "train/gradients/dnn/hidden1/mul_grad/Sum_1" + input: "train/gradients/dnn/hidden1/mul_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/hidden1/mul_grad/Reshape" + input: "^train/gradients/dnn/hidden1/mul_grad/Reshape_1" + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/hidden1/mul_grad/Reshape" + input: "^train/gradients/dnn/hidden1/mul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/mul_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/mul_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/hidden1/mul_grad/Reshape_1" + input: "^train/gradients/dnn/hidden1/mul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/mul_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/AddN" + op: "AddN" + input: "train/gradients/dnn/hidden1/Maximum_grad/tuple/control_dependency_1" + input: "train/gradients/dnn/hidden1/mul_grad/tuple/control_dependency_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/Maximum_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/Shape" + op: "Shape" + input: "dnn/hidden1/MatMul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/Shape_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 300 + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "train/gradients/dnn/hidden1/add_grad/Shape" + input: "train/gradients/dnn/hidden1/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/Sum" + op: "Sum" + input: "train/gradients/AddN" + input: "train/gradients/dnn/hidden1/add_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/Reshape" + op: "Reshape" + input: "train/gradients/dnn/hidden1/add_grad/Sum" + input: "train/gradients/dnn/hidden1/add_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/Sum_1" + op: "Sum" + input: "train/gradients/AddN" + input: "train/gradients/dnn/hidden1/add_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/Reshape_1" + op: "Reshape" + input: "train/gradients/dnn/hidden1/add_grad/Sum_1" + input: "train/gradients/dnn/hidden1/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/hidden1/add_grad/Reshape" + input: "^train/gradients/dnn/hidden1/add_grad/Reshape_1" + } + node { + name: "train/gradients/dnn/hidden1/add_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/hidden1/add_grad/Reshape" + input: "^train/gradients/dnn/hidden1/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/hidden1/add_grad/Reshape_1" + input: "^train/gradients/dnn/hidden1/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/MatMul_grad/MatMul" + op: "MatMul" + input: "train/gradients/dnn/hidden1/add_grad/tuple/control_dependency" + input: "dnn/hidden1/weights/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } + } + node { + name: "train/gradients/dnn/hidden1/MatMul_grad/MatMul_1" + op: "MatMul" + input: "input" + input: "train/gradients/dnn/hidden1/add_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "train/gradients/dnn/hidden1/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^train/gradients/dnn/hidden1/MatMul_grad/MatMul" + input: "^train/gradients/dnn/hidden1/MatMul_grad/MatMul_1" + } + node { + name: "train/gradients/dnn/hidden1/MatMul_grad/tuple/control_dependency" + op: "Identity" + input: "train/gradients/dnn/hidden1/MatMul_grad/MatMul" + input: "^train/gradients/dnn/hidden1/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/MatMul_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "train/gradients/dnn/hidden1/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "train/gradients/dnn/hidden1/MatMul_grad/MatMul_1" + input: "^train/gradients/dnn/hidden1/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@train/gradients/dnn/hidden1/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + } + node { + name: "train/GradientDescent/learning_rate" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.009999999776482582 + } + } + } + } + node { + name: "train/GradientDescent/update_dnn/hidden1/weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "dnn/hidden1/weights" + input: "train/GradientDescent/learning_rate" + input: "train/gradients/dnn/hidden1/MatMul_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "train/GradientDescent/update_dnn/hidden1/bias/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "dnn/hidden1/bias" + input: "train/GradientDescent/learning_rate" + input: "train/gradients/dnn/hidden1/add_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "train/GradientDescent/update_dnn/hidden2/weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "dnn/hidden2/weights" + input: "train/GradientDescent/learning_rate" + input: "train/gradients/dnn/hidden2/MatMul_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "train/GradientDescent/update_dnn/hidden2/bias/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "dnn/hidden2/bias" + input: "train/GradientDescent/learning_rate" + input: "train/gradients/dnn/hidden2/add_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "train/GradientDescent/update_dnn/outputs/weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "dnn/outputs/weights" + input: "train/GradientDescent/learning_rate" + input: "train/gradients/dnn/outputs/MatMul_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "train/GradientDescent/update_dnn/outputs/bias/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "dnn/outputs/bias" + input: "train/GradientDescent/learning_rate" + input: "train/gradients/dnn/outputs/add_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "train/GradientDescent" + op: "NoOp" + input: "^train/GradientDescent/update_dnn/hidden1/weights/ApplyGradientDescent" + input: "^train/GradientDescent/update_dnn/hidden1/bias/ApplyGradientDescent" + input: "^train/GradientDescent/update_dnn/hidden2/weights/ApplyGradientDescent" + input: "^train/GradientDescent/update_dnn/hidden2/bias/ApplyGradientDescent" + input: "^train/GradientDescent/update_dnn/outputs/weights/ApplyGradientDescent" + input: "^train/GradientDescent/update_dnn/outputs/bias/ApplyGradientDescent" + } + node { + name: "eval/in_top_k/InTopKV2/k" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } + } + node { + name: "eval/in_top_k/InTopKV2" + op: "InTopKV2" + input: "dnn/outputs/add" + input: "y" + input: "eval/in_top_k/InTopKV2/k" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "eval/Cast" + op: "Cast" + input: "eval/in_top_k/InTopKV2" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_BOOL + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "eval/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "eval/Mean" + op: "Mean" + input: "eval/Cast" + input: "eval/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "init" + op: "NoOp" + input: "^dnn/hidden1/weights/Assign" + input: "^dnn/hidden1/bias/Assign" + input: "^dnn/hidden2/weights/Assign" + input: "^dnn/hidden2/bias/Assign" + input: "^dnn/outputs/weights/Assign" + input: "^dnn/outputs/bias/Assign" + } + node { + name: "Accuracy/tags" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Accuracy" + } + } + } + } + node { + name: "Accuracy" + op: "ScalarSummary" + input: "Accuracy/tags" + input: "eval/Mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "model" + } + } + } + } + node { + name: "save/StringJoin/inputs_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "_temp_de3cfc5e8e7e4734ae221577e8fd36a2/part" + } + } + } + } + node { + name: "save/StringJoin" + op: "StringJoin" + input: "save/Const" + input: "save/StringJoin/inputs_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "separator" + value { + s: "" + } + } + } + node { + name: "save/num_shards" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "save/ShardedFilename/shard" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "save/ShardedFilename" + op: "ShardedFilename" + input: "save/StringJoin" + input: "save/ShardedFilename/shard" + input: "save/num_shards" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/SaveV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 6 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 6 + } + } + string_val: "dnn/hidden1/bias" + string_val: "dnn/hidden1/weights" + string_val: "dnn/hidden2/bias" + string_val: "dnn/hidden2/weights" + string_val: "dnn/outputs/bias" + string_val: "dnn/outputs/weights" + } + } + } + } + node { + name: "save/SaveV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 6 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 6 + } + } + string_val: "" + string_val: "" + string_val: "" + string_val: "" + string_val: "" + string_val: "" + } + } + } + } + node { + name: "save/SaveV2" + op: "SaveV2" + input: "save/ShardedFilename" + input: "save/SaveV2/tensor_names" + input: "save/SaveV2/shape_and_slices" + input: "dnn/hidden1/bias" + input: "dnn/hidden1/weights" + input: "dnn/hidden2/bias" + input: "dnn/hidden2/weights" + input: "dnn/outputs/bias" + input: "dnn/outputs/weights" + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + } + } + } + } + node { + name: "save/control_dependency" + op: "Identity" + input: "save/ShardedFilename" + input: "^save/SaveV2" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@save/ShardedFilename" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/MergeV2Checkpoints/checkpoint_prefixes" + op: "Pack" + input: "save/ShardedFilename" + input: "^save/control_dependency" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "save/MergeV2Checkpoints" + op: "MergeV2Checkpoints" + input: "save/MergeV2Checkpoints/checkpoint_prefixes" + input: "save/Const" + attr { + key: "delete_old_dirs" + value { + b: true + } + } + } + node { + name: "save/Identity" + op: "Identity" + input: "save/Const" + input: "^save/control_dependency" + input: "^save/MergeV2Checkpoints" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/RestoreV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "dnn/hidden1/bias" + } + } + } + } + node { + name: "save/RestoreV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2/tensor_names" + input: "save/RestoreV2/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign" + op: "Assign" + input: "dnn/hidden1/bias" + input: "save/RestoreV2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_1/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "dnn/hidden1/weights" + } + } + } + } + node { + name: "save/RestoreV2_1/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_1" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_1/tensor_names" + input: "save/RestoreV2_1/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_1" + op: "Assign" + input: "dnn/hidden1/weights" + input: "save/RestoreV2_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden1/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 300 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "dnn/hidden2/bias" + } + } + } + } + node { + name: "save/RestoreV2_2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_2" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_2/tensor_names" + input: "save/RestoreV2_2/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_2" + op: "Assign" + input: "dnn/hidden2/bias" + input: "save/RestoreV2_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_3/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "dnn/hidden2/weights" + } + } + } + } + node { + name: "save/RestoreV2_3/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_3" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_3/tensor_names" + input: "save/RestoreV2_3/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_3" + op: "Assign" + input: "dnn/hidden2/weights" + input: "save/RestoreV2_3" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/hidden2/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 300 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_4/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "dnn/outputs/bias" + } + } + } + } + node { + name: "save/RestoreV2_4/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_4" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_4/tensor_names" + input: "save/RestoreV2_4/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_4" + op: "Assign" + input: "dnn/outputs/bias" + input: "save/RestoreV2_4" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_5/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "dnn/outputs/weights" + } + } + } + } + node { + name: "save/RestoreV2_5/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_5" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_5/tensor_names" + input: "save/RestoreV2_5/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_5" + op: "Assign" + input: "dnn/outputs/weights" + input: "save/RestoreV2_5" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@dnn/outputs/weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/restore_shard" + op: "NoOp" + input: "^save/Assign" + input: "^save/Assign_1" + input: "^save/Assign_2" + input: "^save/Assign_3" + input: "^save/Assign_4" + input: "^save/Assign_5" + } + node { + name: "save/restore_all" + op: "NoOp" + input: "^save/restore_shard" + } + versions { + producer: 24 + } + } + saver_def { + filename_tensor_name: "save/Const:0" + save_tensor_name: "save/Identity:0" + restore_op_name: "save/restore_all" + max_to_keep: 5 + sharded: true + keep_checkpoint_every_n_hours: 10000.0 + version: V2 + } + collection_def { + key: "summaries" + value { + node_list { + value: "Accuracy:0" + } + } + } + collection_def { + key: "train_op" + value { + node_list { + value: "train/GradientDescent" + } + } + } + collection_def { + key: "trainable_variables" + value { + bytes_list { + value: "\n\025dnn/hidden1/weights:0\022\032dnn/hidden1/weights/Assign\032\032dnn/hidden1/weights/read:02\036dnn/hidden1/truncated_normal:0" + value: "\n\022dnn/hidden1/bias:0\022\027dnn/hidden1/bias/Assign\032\027dnn/hidden1/bias/read:02\023dnn/hidden1/zeros:0" + value: "\n\025dnn/hidden2/weights:0\022\032dnn/hidden2/weights/Assign\032\032dnn/hidden2/weights/read:02\036dnn/hidden2/truncated_normal:0" + value: "\n\022dnn/hidden2/bias:0\022\027dnn/hidden2/bias/Assign\032\027dnn/hidden2/bias/read:02\023dnn/hidden2/zeros:0" + value: "\n\025dnn/outputs/weights:0\022\032dnn/outputs/weights/Assign\032\032dnn/outputs/weights/read:02\036dnn/outputs/truncated_normal:0" + value: "\n\022dnn/outputs/bias:0\022\027dnn/outputs/bias/Assign\032\027dnn/outputs/bias/read:02\023dnn/outputs/zeros:0" + } + } + } + collection_def { + key: "variables" + value { + bytes_list { + value: "\n\025dnn/hidden1/weights:0\022\032dnn/hidden1/weights/Assign\032\032dnn/hidden1/weights/read:02\036dnn/hidden1/truncated_normal:0" + value: "\n\022dnn/hidden1/bias:0\022\027dnn/hidden1/bias/Assign\032\027dnn/hidden1/bias/read:02\023dnn/hidden1/zeros:0" + value: "\n\025dnn/hidden2/weights:0\022\032dnn/hidden2/weights/Assign\032\032dnn/hidden2/weights/read:02\036dnn/hidden2/truncated_normal:0" + value: "\n\022dnn/hidden2/bias:0\022\027dnn/hidden2/bias/Assign\032\027dnn/hidden2/bias/read:02\023dnn/hidden2/zeros:0" + value: "\n\025dnn/outputs/weights:0\022\032dnn/outputs/weights/Assign\032\032dnn/outputs/weights/read:02\036dnn/outputs/truncated_normal:0" + value: "\n\022dnn/outputs/bias:0\022\027dnn/outputs/bias/Assign\032\027dnn/outputs/bias/read:02\023dnn/outputs/zeros:0" + } + } + } + signature_def { + key: "serving_default" + value { + inputs { + key: "x" + value { + name: "input:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + outputs { + key: "y" + value { + name: "dnn/outputs/add:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + method_name: "tensorflow/serving/predict" + } + } +} diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 b/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 Binary files differnew file mode 100644 index 00000000000..a7ca01888c7 --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index b/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index Binary files differnew file mode 100644 index 00000000000..7989c109a3a --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py b/searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py new file mode 100644 index 00000000000..26529f67919 --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py @@ -0,0 +1,97 @@ + +# Common imports +import numpy as np +import tensorflow as tf + +from tensorflow.examples.tutorials.mnist import input_data +from datetime import datetime + +now = datetime.utcnow().strftime("%Y%m%d%H%M%S") +root_logdir = "tf_logs" +logdir = "{}/run-{}/".format(root_logdir, now) + +mnist = input_data.read_data_sets("/tmp/data/") +X_train = mnist.train.images +X_test = mnist.test.images +y_train = mnist.train.labels.astype("int") +y_test = mnist.test.labels.astype("int") + +n_inputs = 28*28 # MNIST +n_hidden1 = 300 +n_hidden2 = 100 +n_hidden3 = 40 +n_outputs = 10 + +learning_rate = 0.01 +n_epochs = 20 +batch_size = 50 + +input = tf.placeholder(tf.float32, shape=(None, n_inputs), name="input") +y = tf.placeholder(tf.int64, shape=(None), name="y") + + +def neuron_layer(X, n_neurons, name, activation=None): + with tf.name_scope(name): + n_inputs = int(X.get_shape()[1]) + stddev = 2 / np.sqrt(n_inputs) + init = tf.truncated_normal((n_inputs, n_neurons), stddev=stddev) + W = tf.Variable(init, name="weights") + b = tf.Variable(tf.zeros([n_neurons]), name="bias") + Z = tf.matmul(X, W) + b + if activation is not None: + return activation(Z) + else: + return Z + + +def leaky_relu(z, name=None): + return tf.maximum(0.01 * z, z, name=name) + + +with tf.name_scope("dnn"): + hidden1 = neuron_layer(input, n_hidden1, name="hidden1", activation=leaky_relu) + hidden2 = neuron_layer(hidden1, n_hidden2, name="hidden2", activation=tf.nn.selu) + logits = neuron_layer(hidden2, n_outputs, name="outputs") #, activation=tf.nn.sigmoid) + +with tf.name_scope("loss"): + xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits) + loss = tf.reduce_mean(xentropy, name="loss") + +with tf.name_scope("train"): + optimizer = tf.train.GradientDescentOptimizer(learning_rate) + training_op = optimizer.minimize(loss) + +with tf.name_scope("eval"): + correct = tf.nn.in_top_k(logits, y, 1) + accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) + +init = tf.global_variables_initializer() +accuracy_summary = tf.summary.scalar('Accuracy', accuracy) +file_writer = tf.summary.FileWriter(logdir, tf.get_default_graph()) + +with tf.Session() as sess: + init.run() + for epoch in range(n_epochs): + for iteration in range(mnist.train.num_examples // batch_size): + X_batch, y_batch = mnist.train.next_batch(batch_size) + sess.run(training_op, feed_dict={input: X_batch, y: y_batch}) + acc_train = accuracy.eval(feed_dict={input: X_batch, y: y_batch}) + acc_val = accuracy.eval(feed_dict={input: mnist.validation.images, + y: mnist.validation.labels}) + print(epoch, "Train accuracy:", acc_train, "Val accuracy:", acc_val) + + # Save summary for tensorboard + summary_str = accuracy_summary.eval(feed_dict={input: mnist.validation.images, + y: mnist.validation.labels}) + file_writer.add_summary(summary_str, epoch) + + export_path = "saved" + print('Exporting trained model to ', export_path) + builder = tf.saved_model.builder.SavedModelBuilder(export_path) + signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs = {'x':input}, outputs = {'y':logits}) + builder.add_meta_graph_and_variables(sess, + [tf.saved_model.tag_constants.SERVING], + signature_def_map={'serving_default':signature}) + builder.save(as_text=True) + +file_writer.close()
\ No newline at end of file diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java new file mode 100644 index 00000000000..add66eece1a --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java @@ -0,0 +1,30 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.ml; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * @author bratseth + */ +public class MnistImportTestCase { + + @Test + public void testMnistImport() { + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist/saved"); + ImportedModel.Signature signature = model.get().signature("serving_default"); + + assertEquals("Has skipped outputs", + 0, model.get().signature("serving_default").skippedOutputs().size()); + + RankingExpression output = signature.outputExpression("y"); + assertNotNull(output); + assertEquals("dnn/outputs/add", output.getName()); + model.assertEqualResultSum("input", output.getName(), 0.00001); + } + + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java index a7926cd2e02..bcfc6ce0a04 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java @@ -7,9 +7,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; -import org.tensorflow.SavedModelBundle; - -import java.io.IOException; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -21,7 +18,7 @@ import static org.junit.Assert.assertTrue; public class OnnxMnistSoftmaxImportTestCase { @Test - public void testMnistSoftmaxImport() throws IOException { + public void testMnistSoftmaxImport() { ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); // Check constants @@ -43,14 +40,14 @@ public class OnnxMnistSoftmaxImportTestCase { assertEquals(1, model.requiredMacros().size()); assertTrue(model.requiredMacros().containsKey("Placeholder")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.requiredMacros().get("Placeholder")); + model.requiredMacros().get("Placeholder")); // Check outputs RankingExpression output = model.defaultSignature().outputExpression("add"); assertNotNull(output); assertEquals("add", output.getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", - output.getRoot().toString()); + output.getRoot().toString()); } @Test diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java index bd7644be23b..dd6c8095e3c 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java @@ -13,7 +13,7 @@ import static org.junit.Assert.assertTrue; /** * @author bratseth */ -public class MnistSoftmaxImportTestCase { +public class TensorFlowMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java index 723c5f27914..4de3aa5d635 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java @@ -36,11 +36,26 @@ public class TestableTensorFlowModel { public TestableTensorFlowModel(String modelName, String modelDir) { tensorFlowModel = SavedModelBundle.load(modelDir, "serve"); - model = new TensorFlowImporter().importModel(modelName, tensorFlowModel); + model = new TensorFlowImporter().importModel(modelName, modelDir, tensorFlowModel); } public ImportedModel get() { return model; } + /** Compare that summing the tensors produce the same result to within some tolerance delta */ + public void assertEqualResultSum(String inputName, String operationName, double delta) { + Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName); + Context context = contextFrom(model); + Tensor placeholder = placeholderArgument(); + context.put(inputName, new TensorValue(placeholder)); + + model.macros().forEach((k,v) -> evaluateMacro(context, model, k)); + + Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); + assertEquals("Operation '" + operationName + "' produces equal results", + tfResult.sum().asDouble(), vespaResult.sum().asDouble(), delta); + } + + /** Compare tensors 100% exactly */ public void assertEqualResult(String inputName, String operationName) { Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName); Context context = contextFrom(model); diff --git a/searchlib/src/tests/expression/attributenode/CMakeLists.txt b/searchlib/src/tests/expression/attributenode/CMakeLists.txt new file mode 100644 index 00000000000..c7df5458bb7 --- /dev/null +++ b/searchlib/src/tests/expression/attributenode/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(searchlib_attribute_node_test_app TEST + SOURCES + attribute_node_test.cpp + DEPENDS + searchlib +) +vespa_add_test(NAME searchlib_attribute_node_test_app COMMAND searchlib_attribute_node_test_app) diff --git a/searchlib/src/tests/expression/attributenode/attribute_node_test.cpp b/searchlib/src/tests/expression/attributenode/attribute_node_test.cpp new file mode 100644 index 00000000000..c92b5fc4808 --- /dev/null +++ b/searchlib/src/tests/expression/attributenode/attribute_node_test.cpp @@ -0,0 +1,429 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/searchcommon/common/undefinedvalues.h> +#include <vespa/searchlib/attribute/attributefactory.h> +#include <vespa/searchlib/attribute/attributecontext.h> +#include <vespa/searchlib/attribute/attributemanager.h> +#include <vespa/searchlib/attribute/attributevector.h> +#include <vespa/searchlib/attribute/attributevector.hpp> +#include <vespa/searchlib/attribute/floatbase.h> +#include <vespa/searchlib/attribute/integerbase.h> +#include <vespa/searchlib/attribute/stringbase.h> +#include <vespa/searchlib/expression/attributenode.h> +#include <vespa/searchlib/expression/attribute_keyed_node.h> +#include <vespa/searchlib/expression/resultvector.h> +#include <vespa/vespalib/test/insertion_operators.h> +#include <vespa/vespalib/testkit/testapp.h> + +#include <vespa/log/log.h> +LOG_SETUP("attribute_node_test"); + +using search::AttributeContext; +using search::AttributeFactory; +using search::AttributeManager; +using search::AttributeVector; +using search::IntegerAttribute; +using search::FloatingPointAttribute; +using search::StringAttribute; +using search::attribute::BasicType; +using search::attribute::CollectionType; +using search::attribute::Config; +using search::attribute::IAttributeVector; +using search::attribute::getUndefined; +using search::expression::AttributeNode; +using search::expression::AttributeKeyedNode; +using search::expression::EnumResultNode; +using search::expression::EnumResultNodeVector; +using search::expression::FloatResultNode; +using search::expression::FloatResultNodeVector; +using search::expression::Int8ResultNode; +using search::expression::Int8ResultNodeVector; +using search::expression::IntegerResultNodeVector; +using search::expression::IntegerResultNode; +using search::expression::ResultNode; +using search::expression::ResultNodeVector; +using search::expression::StringResultNode; +using search::expression::StringResultNodeVector; +using vespalib::BufferRef; + +namespace { + +vespalib::string stringValue(const ResultNode &result, const IAttributeVector &attr) { + if (result.inherits(EnumResultNode::classId)) { + auto enumHandle = result.getEnum(); + auto &stringAttr = dynamic_cast<const StringAttribute &>(attr); + return vespalib::string(stringAttr.getFromEnum(enumHandle)); + } + char buf[100]; + BufferRef bref(&buf[0], sizeof(buf)); + auto sbuf = result.getString(bref); + return vespalib::string(sbuf.c_str(), sbuf.c_str() + sbuf.size()); +} + +struct AttributeManagerFixture +{ + AttributeManager mgr; + + AttributeManagerFixture(); + ~AttributeManagerFixture(); + template <typename AttributeType, typename ValueType> + void buildAttribute(const vespalib::string &name, BasicType type, std::vector<ValueType> values); + void buildStringAttribute(const vespalib::string &name, std::vector<vespalib::string> values); + void buildFloatAttribute(const vespalib::string &name, std::vector<double> values); + void buildIntegerAttribute(const vespalib::string &name, BasicType type, std::vector<IAttributeVector::largeint_t> values); + template <typename AttributeType, typename ValueType> + void buildArrayAttribute(const vespalib::string &name, BasicType type, std::vector<std::vector<ValueType>> values); + void buildStringArrayAttribute(const vespalib::string &name,std::vector<std::vector<vespalib::string>> values); + void buildFloatArrayAttribute(const vespalib::string &name, std::vector<std::vector<double>> values); + void buildIntegerArrayAttribute(const vespalib::string &name, BasicType type, std::vector<std::vector<IAttributeVector::largeint_t>> values); +}; + +AttributeManagerFixture::AttributeManagerFixture() + : mgr() +{ + buildStringAttribute("sfield", { "n1", ""}); + buildIntegerAttribute("ifield", BasicType::Type::INT8, { 10, getUndefined<int8_t>() }); + buildFloatAttribute("ffield", { 110.0, getUndefined<double>() }); + buildStringArrayAttribute("array.name", {{"n1.1", "n1.2"}, {"n2"}, {}}); + buildIntegerArrayAttribute("array.val", BasicType::Type::INT8, {{ 10, 11}, {20, 21 }, {}}); + buildFloatArrayAttribute("array.fval", {{ 110.0}, { 120.0, 121.0 }, {}}); + buildStringArrayAttribute("smap.key", {{"k1.1", "k1.2"}, {"k2"}, {}}); + buildStringArrayAttribute("smap.value.name", {{"n1.1", "n1.2"}, {"n2"}, {}}); + buildIntegerArrayAttribute("smap.value.val", BasicType::Type::INT8, {{ 10, 11}, {20, 21 }, {}}); + buildFloatArrayAttribute("smap.value.fval", {{ 110.0}, { 120.0, 121.0 }, {}}); + buildStringArrayAttribute("map.key", {{"k1.1", "k1.2"}, {"k2"}, {}}); + buildStringArrayAttribute("map.value", {{"n1.1", "n1.2"}, {"n2"}, {}}); + buildStringAttribute("keyfield1", {"k1.2", "k2", "k3"}); + buildStringAttribute("keyfield2", {"k1.1", "k1", "k1"}); +} + +AttributeManagerFixture::~AttributeManagerFixture() = default; + +template <typename AttributeType, typename ValueType> +void +AttributeManagerFixture::buildAttribute(const vespalib::string &name, + BasicType type, + std::vector<ValueType> values) +{ + Config cfg(type, CollectionType::Type::SINGLE); + auto attrBase = AttributeFactory::createAttribute(name, cfg); + EXPECT_TRUE(attrBase); + auto attr = std::dynamic_pointer_cast<AttributeType>(attrBase); + EXPECT_TRUE(attr); + attr->addReservedDoc(); + for (const auto &value : values) { + uint32_t docId = 0; + EXPECT_TRUE(attr->addDoc(docId)); + EXPECT_NOT_EQUAL(0u, docId); + attr->update(docId, value); + attr->commit(); + } + EXPECT_TRUE(mgr.add(attr)); +} + +void +AttributeManagerFixture::buildStringAttribute(const vespalib::string &name, + std::vector<vespalib::string> values) +{ + buildAttribute<StringAttribute, vespalib::string>(name, BasicType::Type::STRING, std::move(values)); +} + +void +AttributeManagerFixture::buildFloatAttribute(const vespalib::string &name, + std::vector<double> values) +{ + buildAttribute<FloatingPointAttribute, double>(name, BasicType::Type::DOUBLE, std::move(values)); +} + +void +AttributeManagerFixture::buildIntegerAttribute(const vespalib::string &name, + BasicType type, + std::vector<IAttributeVector::largeint_t> values) +{ + buildAttribute<IntegerAttribute, IAttributeVector::largeint_t>(name, type, std::move(values)); +} + +template <typename AttributeType, typename ValueType> +void +AttributeManagerFixture::buildArrayAttribute(const vespalib::string &name, + BasicType type, + std::vector<std::vector<ValueType>> values) +{ + Config cfg(type, CollectionType::Type::ARRAY); + auto attrBase = AttributeFactory::createAttribute(name, cfg); + EXPECT_TRUE(attrBase); + auto attr = std::dynamic_pointer_cast<AttributeType>(attrBase); + EXPECT_TRUE(attr); + attr->addReservedDoc(); + for (const auto &docValues : values) { + uint32_t docId = 0; + EXPECT_TRUE(attr->addDoc(docId)); + EXPECT_NOT_EQUAL(0u, docId); + for (const auto &value : docValues) { + attr->append(docId, value, 1); + } + attr->commit(); + } + EXPECT_TRUE(mgr.add(attr)); +} + +void +AttributeManagerFixture::buildStringArrayAttribute(const vespalib::string &name, + std::vector<std::vector<vespalib::string>> values) +{ + buildArrayAttribute<StringAttribute, vespalib::string>(name, BasicType::Type::STRING, std::move(values)); +} + +void +AttributeManagerFixture::buildFloatArrayAttribute(const vespalib::string &name, + std::vector<std::vector<double>> values) +{ + buildArrayAttribute<FloatingPointAttribute, double>(name, BasicType::Type::DOUBLE, std::move(values)); +} + +void +AttributeManagerFixture::buildIntegerArrayAttribute(const vespalib::string &name, + BasicType type, + std::vector<std::vector<IAttributeVector::largeint_t>> values) +{ + buildArrayAttribute<IntegerAttribute, IAttributeVector::largeint_t>(name, type, std::move(values)); +} + + +struct Fixture +{ + AttributeManagerFixture attrs; + AttributeContext context; + Fixture(); + ~Fixture(); + std::unique_ptr<AttributeNode> makeNode(const vespalib::string &attributeName, bool useEnumOptimiation = false, bool preserveAccurateTypes = false); + void assertInts(std::vector<IAttributeVector::largeint_t> expVals, const vespalib::string &attributteName, bool preserveAccurateTypes = false); + void assertStrings(std::vector<vespalib::string> expVals, const vespalib::string &attributteName, bool useEnumOptimization = false); + void assertFloats(std::vector<double> expVals, const vespalib::string &attributteName); + void assertIntArrays(std::vector<std::vector<IAttributeVector::largeint_t>> expVals, const vespalib::string &attributteName, bool preserveAccurateTypes = false); + void assertStringArrays(std::vector<std::vector<vespalib::string>> expVals, const vespalib::string &attributteName, bool useEnumOptimization = false); + void assertFloatArrays(std::vector<std::vector<double>> expVals, const vespalib::string &attributteName); +}; + +Fixture::Fixture() + : attrs(), + context(attrs.mgr) +{ +} + +Fixture::~Fixture() = default; + +std::unique_ptr<AttributeNode> +Fixture::makeNode(const vespalib::string &attributeName, bool useEnumOptimization, bool preserveAccurateTypes) +{ + std::unique_ptr<AttributeNode> node; + if (attributeName.find('{') == vespalib::string::npos) { + node = std::make_unique<AttributeNode>(attributeName); + } else { + node = std::make_unique<AttributeKeyedNode>(attributeName); + } + if (useEnumOptimization) { + node->useEnumOptimization(); + } + AttributeNode::Configure configure(context); + node->select(configure, configure); + node->prepare(preserveAccurateTypes); + return node; +} + + +void +Fixture::assertInts(std::vector<IAttributeVector::largeint_t> expVals, const vespalib::string &attributeName, bool preserveAccurateTypes) +{ + auto node = makeNode(attributeName, false, preserveAccurateTypes); + uint32_t docId = 0; + for (const auto &expDocVal : expVals) { + ++docId; + node->setDocId(docId); + node->execute(); + const auto &result = node->getResult(); + if (preserveAccurateTypes) { + ASSERT_TRUE(result.inherits(Int8ResultNode::classId)); + } else { + ASSERT_TRUE(result.inherits(IntegerResultNode::classId)); + } + IAttributeVector::largeint_t docVal = result.getInteger(); + EXPECT_EQUAL(expDocVal, docVal); + } +} + +void +Fixture::assertStrings(std::vector<vespalib::string> expVals, const vespalib::string &attributeName, bool useEnumOptimization) +{ + auto node = makeNode(attributeName, useEnumOptimization); + uint32_t docId = 0; + for (const auto &expDocVal : expVals) { + ++docId; + node->setDocId(docId); + node->execute(); + const auto &result = node->getResult(); + if (useEnumOptimization) { + ASSERT_TRUE(result.inherits(EnumResultNode::classId)); + } else { + ASSERT_TRUE(result.inherits(StringResultNode::classId)); + } + vespalib::string docVal = stringValue(result, *node->getAttribute()); + EXPECT_EQUAL(expDocVal, docVal); + } +} + +void +Fixture::assertFloats(std::vector<double> expVals, const vespalib::string &attributeName) +{ + auto node = makeNode(attributeName); + uint32_t docId = 0; + for (const auto &expDocVal : expVals) { + ++docId; + node->setDocId(docId); + node->execute(); + const auto &result = node->getResult(); + ASSERT_TRUE(result.inherits(FloatResultNode::classId)); + double docVal = result.getFloat(); + EXPECT_EQUAL(std::isnan(expDocVal), std::isnan(docVal)); + if (!std::isnan(expDocVal)) { + EXPECT_EQUAL(expDocVal, docVal); + } + } +} + +void +Fixture::assertIntArrays(std::vector<std::vector<IAttributeVector::largeint_t>> expVals, const vespalib::string &attributeName, bool preserveAccurateTypes) +{ + auto node = makeNode(attributeName, false, preserveAccurateTypes); + uint32_t docId = 0; + for (const auto &expDocVals : expVals) { + ++docId; + node->setDocId(docId); + node->execute(); + const auto &result = node->getResult(); + ASSERT_TRUE(result.inherits(ResultNodeVector::classId)); + const auto &resultVector = static_cast<const ResultNodeVector &>(result); + if (preserveAccurateTypes) { + ASSERT_TRUE(result.inherits(Int8ResultNodeVector::classId)); + } else { + ASSERT_TRUE(result.inherits(IntegerResultNodeVector::classId)); + } + std::vector<IAttributeVector::largeint_t> docVals; + for (size_t i = 0; i < resultVector.size(); ++i) { + docVals.push_back(resultVector.get(i).getInteger()); + } + EXPECT_EQUAL(expDocVals, docVals); + } +} + +void +Fixture::assertStringArrays(std::vector<std::vector<vespalib::string>> expVals, const vespalib::string &attributeName, bool useEnumOptimization) +{ + auto node = makeNode(attributeName, useEnumOptimization); + uint32_t docId = 0; + for (const auto &expDocVals : expVals) { + ++docId; + node->setDocId(docId); + node->execute(); + const auto &result = node->getResult(); + ASSERT_TRUE(result.inherits(ResultNodeVector::classId)); + const auto &resultVector = static_cast<const ResultNodeVector &>(result); + if (useEnumOptimization) { + ASSERT_TRUE(result.inherits(EnumResultNodeVector::classId)); + } else { + ASSERT_TRUE(result.inherits(StringResultNodeVector::classId)); + } + std::vector<vespalib::string> docVals; + for (size_t i = 0; i < resultVector.size(); ++i) { + docVals.push_back(stringValue(resultVector.get(i), *node->getAttribute())); + } + EXPECT_EQUAL(expDocVals, docVals); + } +} + +void +Fixture::assertFloatArrays(std::vector<std::vector<double>> expVals, const vespalib::string &attributeName) +{ + auto node = makeNode(attributeName); + uint32_t docId = 0; + for (const auto &expDocVals : expVals) { + ++docId; + node->setDocId(docId); + node->execute(); + const auto &result = node->getResult(); + ASSERT_TRUE(result.inherits(ResultNodeVector::classId)); + const auto &resultVector = static_cast<const ResultNodeVector &>(result); + ASSERT_TRUE(result.inherits(FloatResultNodeVector::classId)); + std::vector<double> docVals; + for (size_t i = 0; i < resultVector.size(); ++i) { + docVals.push_back(resultVector.get(i).getFloat()); + } + EXPECT_EQUAL(expDocVals.size(), docVals.size()); + for (size_t i = 0; i < expDocVals.size(); ++i) { + EXPECT_EQUAL(std::isnan(expDocVals[i]), std::isnan(docVals[i])); + if (!std::isnan(expDocVals[i])) { + EXPECT_EQUAL(expDocVals[i], docVals[i]); + } + } + } +} + +TEST_F("test single values", Fixture) +{ + TEST_DO(f.assertInts({ 10, getUndefined<int8_t>()}, "ifield")); + TEST_DO(f.assertInts({ 10, getUndefined<int8_t>()}, "ifield", true)); + TEST_DO(f.assertStrings({ "n1", "" }, "sfield")); + TEST_DO(f.assertStrings({ "n1", "" }, "sfield", true)); + TEST_DO(f.assertFloats({ 110.0, getUndefined<double>() }, "ffield")); +} + +TEST_F("Test array values", Fixture) +{ + TEST_DO(f.assertIntArrays({{ 10, 11}, {20, 21 }, {}}, "array.val")); + TEST_DO(f.assertIntArrays({{ 10, 11}, {20, 21 }, {}}, "array.val", true)); + TEST_DO(f.assertStringArrays({{"n1.1", "n1.2"}, {"n2"}, {}}, "array.name")); + TEST_DO(f.assertStringArrays({{"n1.1", "n1.2"}, {"n2"}, {}}, "array.name", true)); + TEST_DO(f.assertFloatArrays({{ 110.0}, { 120.0, 121.0 }, {}}, "array.fval")); + TEST_DO(f.assertStringArrays({{"k1.1", "k1.2"}, {"k2"}, {}}, "smap.key")); + TEST_DO(f.assertStringArrays({{"n1.1", "n1.2"}, {"n2"}, {}}, "smap.value.name")); + TEST_DO(f.assertIntArrays({{ 10, 11}, {20, 21 }, {}}, "smap.value.val")); + TEST_DO(f.assertFloatArrays({{ 110.0}, { 120.0, 121.0 }, {}}, "smap.value.fval")); + TEST_DO(f.assertStringArrays({{"k1.1", "k1.2"}, {"k2"}, {}}, "map.key")); + TEST_DO(f.assertStringArrays({{"n1.1", "n1.2"}, {"n2"}, {}}, "map.value")); +} + +TEST_F("test keyed values", Fixture) +{ + TEST_DO(f.assertStrings({"n1.1", "", ""}, "smap{\"k1.1\"}.name")); + TEST_DO(f.assertStrings({"n1.2", "", ""}, "smap{\"k1.2\"}.name")); + TEST_DO(f.assertStrings({"", "n2", ""}, "smap{\"k2\"}.name")); + TEST_DO(f.assertStrings({"", "", ""}, "smap{\"k5\"}.name")); + TEST_DO(f.assertFloats({ 110.0, getUndefined<double>(), getUndefined<double>()}, "smap{\"k1.1\"}.fval")); + TEST_DO(f.assertFloats({ getUndefined<double>(), getUndefined<double>(), getUndefined<double>()}, "smap{\"k1.2\"}.fval")); + TEST_DO(f.assertFloats({ getUndefined<double>(), 120.0, getUndefined<double>()}, "smap{\"k2\"}.fval")); + TEST_DO(f.assertFloats({ getUndefined<double>(), getUndefined<double>(), getUndefined<double>()}, "smap{\"k5\"}.fval")); + TEST_DO(f.assertInts({ 10, getUndefined<int8_t>(), getUndefined<int8_t>()}, "smap{\"k1.1\"}.val")); + TEST_DO(f.assertInts({ 11, getUndefined<int8_t>(), getUndefined<int8_t>()}, "smap{\"k1.2\"}.val")); + TEST_DO(f.assertInts({ getUndefined<int8_t>(), 20, getUndefined<int8_t>()}, "smap{\"k2\"}.val")); + TEST_DO(f.assertInts({ getUndefined<int8_t>(), getUndefined<int8_t>(), getUndefined<int8_t>()}, "smap{\"k5\"}.val")); + TEST_DO(f.assertStrings({"n1.1", "", ""}, "map{\"k1.1\"}")); + TEST_DO(f.assertStrings({"n1.2", "", ""}, "map{\"k1.2\"}")); + TEST_DO(f.assertStrings({"", "n2", ""}, "map{\"k2\"}")); + TEST_DO(f.assertStrings({"", "", ""}, "map{\"k5\"}")); +} + +TEST_F("test indirectly keyed values", Fixture) +{ + TEST_DO(f.assertStrings({"n1.2", "n2", ""}, "map{attribute(keyfield1)}")); + TEST_DO(f.assertStrings({"n1.1", "", ""}, "map{attribute(keyfield2)}")); + TEST_DO(f.assertStrings({"n1.2", "n2", ""}, "smap{attribute(keyfield1)}.name")); + TEST_DO(f.assertStrings({"n1.1", "", ""}, "smap{attribute(keyfield2)}.name")); + TEST_DO(f.assertFloats({ getUndefined<double>(), 120.0, getUndefined<double>()}, "smap{attribute(keyfield1)}.fval")); + TEST_DO(f.assertFloats({ 110.0, getUndefined<double>(), getUndefined<double>()}, "smap{attribute(keyfield2)}.fval")); + TEST_DO(f.assertInts({ 11, 20, getUndefined<int8_t>()}, "smap{attribute(keyfield1)}.val")); + TEST_DO(f.assertInts({ 10, getUndefined<int8_t>(), getUndefined<int8_t>()}, "smap{attribute(keyfield2)}.val")); +} + +} + +TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchlib/src/vespa/searchlib/aggregation/modifiers.cpp b/searchlib/src/vespa/searchlib/aggregation/modifiers.cpp index a59bbe14404..fe71484c4e6 100644 --- a/searchlib/src/vespa/searchlib/aggregation/modifiers.cpp +++ b/searchlib/src/vespa/searchlib/aggregation/modifiers.cpp @@ -4,6 +4,7 @@ #include "grouping.h" #include <vespa/searchlib/expression/multiargfunctionnode.h> #include <vespa/searchlib/expression/attributenode.h> +#include <vespa/searchlib/expression/attribute_keyed_node.h> #include <vespa/searchlib/expression/documentfieldnode.h> using namespace search::expression; @@ -63,6 +64,18 @@ Attribute2DocumentAccessor::getReplacementNode(const AttributeNode &attributeNod return std::make_unique<DocumentFieldNode>(attributeNode.getAttributeName()); } +std::unique_ptr<ExpressionNode> +Attribute2AttributeKeyed::getReplacementNode(const AttributeNode &attributeNode) +{ + const vespalib::string &attributeName = attributeNode.getAttributeName(); + auto lBracePos = attributeName.find('{'); + if (attributeNode.isKeyed() || lBracePos == vespalib::string::npos) { + return std::unique_ptr<ExpressionNode>(); + } else { + return std::make_unique<AttributeKeyedNode>(attributeName); + } +} + } // this function was added by ../../forcelink.sh diff --git a/searchlib/src/vespa/searchlib/aggregation/modifiers.h b/searchlib/src/vespa/searchlib/aggregation/modifiers.h index 0120cb4eac9..6ffda313904 100644 --- a/searchlib/src/vespa/searchlib/aggregation/modifiers.h +++ b/searchlib/src/vespa/searchlib/aggregation/modifiers.h @@ -28,4 +28,10 @@ private: std::unique_ptr<search::expression::ExpressionNode> getReplacementNode(const search::expression::AttributeNode &attributeNode) override; }; +class Attribute2AttributeKeyed : public AttributeNodeReplacer +{ +private: + std::unique_ptr<search::expression::ExpressionNode> getReplacementNode(const search::expression::AttributeNode &attributeNode) override; +}; + } diff --git a/searchlib/src/vespa/searchlib/attribute/attributevector.h b/searchlib/src/vespa/searchlib/attribute/attributevector.h index 8cf4079ccfa..54a43bec09e 100644 --- a/searchlib/src/vespa/searchlib/attribute/attributevector.h +++ b/searchlib/src/vespa/searchlib/attribute/attributevector.h @@ -461,13 +461,6 @@ public: virtual uint32_t clearDoc(DocId doc) = 0; virtual largeint_t getDefaultValue() const = 0; - virtual void getEnumValue(const EnumHandle *v, uint32_t *e, uint32_t sz) const = 0; - - uint32_t getEnumValue(EnumHandle eh) const { - uint32_t e(0); - getEnumValue(&eh, &e, 1); - return e; - } // Implements IAttributeVector virtual uint32_t get(DocId doc, EnumHandle *v, uint32_t sz) const override = 0; diff --git a/searchlib/src/vespa/searchlib/attribute/attrvector.h b/searchlib/src/vespa/searchlib/attribute/attrvector.h index 2ba9ed083f0..c0530ee8368 100644 --- a/searchlib/src/vespa/searchlib/attribute/attrvector.h +++ b/searchlib/src/vespa/searchlib/attribute/attrvector.h @@ -34,11 +34,6 @@ private: NumericDirectAttribute & operator=(const NumericDirectAttribute &); bool onLoad() override; typename B::BaseType getFromEnum(EnumHandle e) const override { return _data[e]; } - void getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const override { - for (size_t i(0); i < sz; i++) { - e[i] = v[i]; - } - } protected: typedef typename B::BaseType BaseType; typedef typename B::DocId DocId; @@ -153,11 +148,6 @@ protected: ~StringDirectAttribute(); bool findEnum(const char * value, EnumHandle & e) const override; std::vector<EnumHandle> findFoldedEnums(const char *) const override; - void getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const override { - for (size_t i(0); i < sz; i++) { - e[i] = v[i]; - } - } void onCommit() override; void onUpdateStat() override { } bool addDoc(DocId & ) override; diff --git a/searchlib/src/vespa/searchlib/attribute/enumattribute.h b/searchlib/src/vespa/searchlib/attribute/enumattribute.h index 26c70d90cfa..993267f79a6 100644 --- a/searchlib/src/vespa/searchlib/attribute/enumattribute.h +++ b/searchlib/src/vespa/searchlib/attribute/enumattribute.h @@ -56,7 +56,6 @@ protected: const EnumStore & getEnumStore() const { return _enumStore; } const EnumStoreBase * getEnumStoreBase() const override { return &_enumStore; } - void getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const override { _enumStore.getEnumValue(v, e, sz); } EnumType getFromEnum(EnumHandle e) const override { return _enumStore.getValue(e); } void fillPostings(LoadedVector & loaded) override { (void) loaded; } diff --git a/searchlib/src/vespa/searchlib/attribute/enumstorebase.cpp b/searchlib/src/vespa/searchlib/attribute/enumstorebase.cpp index 61d862b6c4f..142883e54d6 100644 --- a/searchlib/src/vespa/searchlib/attribute/enumstorebase.cpp +++ b/searchlib/src/vespa/searchlib/attribute/enumstorebase.cpp @@ -136,14 +136,6 @@ EnumStoreBase::getAddressSpaceUsage() const } void -EnumStoreBase::getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const -{ - for(size_t i(0); i < sz; i++) { - e[i] = getEnum(Index(v[i])); - } -} - -void EnumStoreBase::transferHoldLists(generation_t generation) { _enumDict->onTransferHoldLists(generation); diff --git a/searchlib/src/vespa/searchlib/attribute/enumstorebase.h b/searchlib/src/vespa/searchlib/attribute/enumstorebase.h index 9bea2a568e1..9fb91169309 100644 --- a/searchlib/src/vespa/searchlib/attribute/enumstorebase.h +++ b/searchlib/src/vespa/searchlib/attribute/enumstorebase.h @@ -273,7 +273,6 @@ public: size_t getMaxEnumOffset() const { return _store.getBufferState(_store.getActiveBufferId(TYPE_ID)).size(); } - void getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const; uint32_t getRefCount(Index idx) const { return getEntryBase(idx).getRefCount(); } uint32_t getEnum(Index idx) const { return getEntryBase(idx).getEnum(); } void incRefCount(Index idx) { getEntryBase(idx).incRefCount(); } diff --git a/searchlib/src/vespa/searchlib/attribute/multienumattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multienumattribute.hpp index 158343ef7c0..0fd40ab027b 100644 --- a/searchlib/src/vespa/searchlib/attribute/multienumattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multienumattribute.hpp @@ -7,8 +7,6 @@ #include "multienumattributesaver.h" #include "load_utils.h" -#include <stdexcept> - namespace search { template <typename B, typename M> @@ -199,18 +197,10 @@ template <typename B, typename M> std::unique_ptr<AttributeSaver> MultiValueEnumAttribute<B, M>::onInitSave(vespalib::stringref fileName) { - { - this->logEnumStoreEvent("reenumerate", "drain"); - EnumModifier enumGuard(this->getEnumModifier()); - this->logEnumStoreEvent("reenumerate", "start"); - this->_enumStore.reEnumerate(); - } - this->logEnumStoreEvent("reenumerate", "complete"); - vespalib::GenerationHandler::Guard guard(this->getGenerationHandler(). - takeGuard()); + this->_enumStore.reEnumerate(); + vespalib::GenerationHandler::Guard guard(this->getGenerationHandler().takeGuard()); return std::make_unique<MultiValueEnumAttributeSaver<WeightedIndex>> - (std::move(guard), this->createAttributeHeader(fileName), this->_mvMapping, - this->_enumStore); + (std::move(guard), this->createAttributeHeader(fileName), this->_mvMapping, this->_enumStore); } } // namespace search diff --git a/searchlib/src/vespa/searchlib/attribute/multinumericattribute.h b/searchlib/src/vespa/searchlib/attribute/multinumericattribute.h index bea676ff0c3..4b951fd7ceb 100644 --- a/searchlib/src/vespa/searchlib/attribute/multinumericattribute.h +++ b/searchlib/src/vespa/searchlib/attribute/multinumericattribute.h @@ -43,12 +43,6 @@ private: T getFromEnum(EnumHandle e) const override; bool findEnum(T value, EnumHandle & e) const override; - void getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const override { - (void) v; - (void) e; - (void) sz; - } - protected: typedef typename B::generation_t generation_t; diff --git a/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.cpp b/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.cpp index 1dc95c42de8..e9743e3e86d 100644 --- a/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.cpp +++ b/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.cpp @@ -139,11 +139,6 @@ NotImplementedAttribute::getEnum(DocId) const { return 0; } -void -NotImplementedAttribute::getEnumValue(const EnumHandle *, uint32_t *, uint32_t) const { - notImplemented(); -} - bool NotImplementedAttribute::addDoc(DocId &) { notImplemented(); diff --git a/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.h b/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.h index cbd2ff162b2..4552a24ec2e 100644 --- a/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.h +++ b/searchlib/src/vespa/searchlib/attribute/not_implemented_attribute.h @@ -33,7 +33,6 @@ struct NotImplementedAttribute : AttributeVector { uint32_t clearDoc(DocId) override; int64_t getDefaultValue() const override; uint32_t getEnum(DocId) const override; - void getEnumValue(const EnumHandle *, uint32_t *, uint32_t) const override; bool addDoc(DocId &) override; void onAddDocs(DocId lidLimit) override; diff --git a/searchlib/src/vespa/searchlib/attribute/singleenumattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singleenumattribute.hpp index c7299cd71d9..cc9b0346690 100644 --- a/searchlib/src/vespa/searchlib/attribute/singleenumattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/singleenumattribute.hpp @@ -309,15 +309,8 @@ template <typename B> std::unique_ptr<AttributeSaver> SingleValueEnumAttribute<B>::onInitSave(vespalib::stringref fileName) { - { - this->logEnumStoreEvent("reenumerate", "drain"); - EnumModifier enumGuard(this->getEnumModifier()); - this->logEnumStoreEvent("reenumerate", "start"); - this->_enumStore.reEnumerate(); - } - this->logEnumStoreEvent("reenumerate", "complete"); - vespalib::GenerationHandler::Guard guard(this->getGenerationHandler(). - takeGuard()); + this->_enumStore.reEnumerate(); + vespalib::GenerationHandler::Guard guard(this->getGenerationHandler().takeGuard()); return std::make_unique<SingleValueEnumAttributeSaver> (std::move(guard), this->createAttributeHeader(fileName), diff --git a/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.h b/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.h index 06d1068b21a..81fda8b92fc 100644 --- a/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.h +++ b/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.h @@ -125,11 +125,6 @@ public: largeint_t getInt(DocId doc) const override { return static_cast<largeint_t>(getFast(doc)); } - void getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const override { - (void) v; - (void) e; - (void) sz; - } double getFloat(DocId doc) const override { return static_cast<double>(_data[doc]); } diff --git a/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.h b/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.h index f5f666bd89f..d5b65da08fa 100644 --- a/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.h +++ b/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.h @@ -141,11 +141,6 @@ public: largeint_t getInt(DocId doc) const override { return static_cast<largeint_t>(getFast(doc)); } - void getEnumValue(const EnumHandle * v, uint32_t *e, uint32_t sz) const override { - (void) v; - (void) e; - (void) sz; - } double getFloat(DocId doc) const override { return static_cast<double>(getFast(doc)); } diff --git a/searchlib/src/vespa/searchlib/engine/transportserver.cpp b/searchlib/src/vespa/searchlib/engine/transportserver.cpp index c5e59024c31..bc739a7bf48 100644 --- a/searchlib/src/vespa/searchlib/engine/transportserver.cpp +++ b/searchlib/src/vespa/searchlib/engine/transportserver.cpp @@ -7,6 +7,7 @@ #include <vespa/fnet/connection.h> #include <vespa/fnet/connector.h> #include <vespa/fnet/iexecutable.h> +#include <vespa/vespalib/net/crypto_engine.h> #include <vespa/log/log.h> LOG_SETUP(".engine.transportserver"); @@ -358,7 +359,7 @@ TransportServer::TransportServer(SearchServer &searchServer, : _searchServer(searchServer), _docsumServer(docsumServer), _monitorServer(monitorServer), - _transport(), + _transport(std::make_shared<vespalib::NullCryptoEngine>(), 1), // disable encryption _ready(false), _failed(false), _doListen(true), diff --git a/searchlib/src/vespa/searchlib/expression/CMakeLists.txt b/searchlib/src/vespa/searchlib/expression/CMakeLists.txt index 1b7a26bf621..944bc6f63df 100644 --- a/searchlib/src/vespa/searchlib/expression/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/expression/CMakeLists.txt @@ -1,6 +1,7 @@ # Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. vespa_add_library(searchlib_expression OBJECT SOURCES + attribute_keyed_node.cpp attributenode.cpp attributeresult.cpp enumattributeresult.cpp diff --git a/searchlib/src/vespa/searchlib/expression/attribute_keyed_node.cpp b/searchlib/src/vespa/searchlib/expression/attribute_keyed_node.cpp new file mode 100644 index 00000000000..da6ed363b17 --- /dev/null +++ b/searchlib/src/vespa/searchlib/expression/attribute_keyed_node.cpp @@ -0,0 +1,412 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "attribute_keyed_node.h" +#include <vespa/vespalib/stllike/asciistream.h> +#include <vespa/vespalib/util/exceptions.h> +#include <vespa/searchcommon/attribute/attributecontent.h> +#include <vespa/searchcommon/attribute/iattributecontext.h> +#include <vespa/searchcommon/common/undefinedvalues.h> + +using search::attribute::AttributeContent; +using search::attribute::IAttributeVector; +using search::attribute::BasicType; +using search::attribute::getUndefined; +using EnumHandle = IAttributeVector::EnumHandle; + +namespace search::expression { + +class AttributeKeyedNode::KeyHandler +{ +protected: + const IAttributeVector &_attribute; + + KeyHandler(const IAttributeVector &attribute) + : _attribute(attribute) + { + } +public: + static uint32_t noKeyIdx() { return std::numeric_limits<uint32_t>::max(); } + virtual ~KeyHandler() = default; + virtual uint32_t handle(DocId docId) = 0; +}; + +namespace { + +vespalib::string indirectKeyMarker("attribute("); + +class BadKeyHandler : public AttributeKeyedNode::KeyHandler +{ +public: + BadKeyHandler(const IAttributeVector &attribute) + : KeyHandler(attribute) + { + } + uint32_t handle(DocId) override { return noKeyIdx(); } +}; + +template <typename KeyType> +KeyType convertKey(const IAttributeVector &, const vespalib::string &key) +{ + KeyType ret; + vespalib::asciistream is(key); + is >> ret; + return ret; +} + +template <> +vespalib::string convertKey<vespalib::string>(const IAttributeVector &, const vespalib::string &key) +{ + return key; +} + +template <> +EnumHandle convertKey<EnumHandle>(const IAttributeVector &attribute, const vespalib::string &key) +{ + EnumHandle ret; + if (!attribute.findEnum(key.c_str(), ret)) { + ret = EnumHandle(); + } + return ret; +} + +template <typename T, typename KeyType = T> +class KeyHandlerT : public AttributeKeyedNode::KeyHandler +{ + AttributeContent<T> _keys; + KeyType _key; + +public: + KeyHandlerT(const IAttributeVector &attribute, const vespalib::string &key) + : KeyHandler(attribute), + _keys(), + _key(convertKey<KeyType>(attribute, key)) + { + } + ~KeyHandlerT() override; + uint32_t handle(DocId docId) override { + _keys.fill(_attribute, docId); + for (uint32_t i = 0; i < _keys.size(); ++i) { + if (_key == _keys[i]) { + return i; + } + } + return noKeyIdx(); + } +}; + +template <typename T, typename KeyType> +KeyHandlerT<T,KeyType>::~KeyHandlerT() +{ +} + +using IntegerKeyHandler = KeyHandlerT<IAttributeVector::largeint_t>; +using FloatKeyHandler = KeyHandlerT<double>; +using StringKeyHandler = KeyHandlerT<const char *, vespalib::string>; +using EnumKeyHandler = KeyHandlerT<EnumHandle>; + +template <typename T> +bool +matchingKey(T lhs, T rhs) +{ + return lhs == rhs; +} + +template <> +bool +matchingKey<const char *>(const char *lhs, const char *rhs) +{ + return (strcmp(lhs, rhs) == 0); +} + +template <typename T> +class IndirectKeyHandlerT : public AttributeKeyedNode::KeyHandler +{ + const IAttributeVector &_keySourceAttribute; + AttributeContent<T> _keys; + +public: + IndirectKeyHandlerT(const IAttributeVector &attribute, const IAttributeVector &keySourceAttribute) + : KeyHandler(attribute), + _keySourceAttribute(keySourceAttribute), + _keys() + { + } + ~IndirectKeyHandlerT() override; + uint32_t handle(DocId docId) override { + T key = T(); + _keySourceAttribute.get(docId, &key, 1); + _keys.fill(_attribute, docId); + for (uint32_t i = 0; i < _keys.size(); ++i) { + if (matchingKey(key, _keys[i])) { + return i; + } + } + return noKeyIdx(); + } +}; + +template <typename T> +IndirectKeyHandlerT<T>::~IndirectKeyHandlerT() +{ +} + +using IndirectIntegerKeyHandler = IndirectKeyHandlerT<IAttributeVector::largeint_t>; +using IndirectFloatKeyHandler = IndirectKeyHandlerT<double>; +using IndirectStringKeyHandler = IndirectKeyHandlerT<const char *>; + +class ValueHandler : public AttributeNode::Handler +{ +protected: + std::unique_ptr<AttributeKeyedNode::KeyHandler> _keyHandler; + const IAttributeVector &_attribute; + ValueHandler(std::unique_ptr<AttributeKeyedNode::KeyHandler> keyHandler, const IAttributeVector &attribute) + : _keyHandler(std::move(keyHandler)), + _attribute(attribute) + { + } +}; + +template <typename T, typename ResultNodeType> +class ValueHandlerT : public ValueHandler +{ + AttributeContent<T> _values; + ResultNodeType &_result; + T _undefinedValue; +public: + ValueHandlerT(std::unique_ptr<AttributeKeyedNode::KeyHandler> keyHandler, const IAttributeVector &attribute, ResultNodeType &result, T undefinedValue) + : ValueHandler(std::move(keyHandler), attribute), + _values(), + _result(result), + _undefinedValue(undefinedValue) + { + } + void handle(const AttributeResult & r) override { + uint32_t docId = r.getDocId(); + uint32_t keyIdx = _keyHandler->handle(docId); + if (keyIdx != AttributeKeyedNode::KeyHandler::noKeyIdx()) { + _values.fill(_attribute, docId); + if (keyIdx < _values.size()) { + _result = _values[keyIdx]; + return; + } + } + _result = _undefinedValue; + } +}; + +template <typename ResultNodeType> +using IntegerValueHandler = ValueHandlerT<IAttributeVector::largeint_t, ResultNodeType>; +using FloatValueHandler = ValueHandlerT<double, FloatResultNode>; +using StringValueHandler = ValueHandlerT<const char *, StringResultNode>; +using EnumValueHandler = ValueHandlerT<EnumHandle, EnumResultNode>; + +const IAttributeVector *findAttribute(const search::attribute::IAttributeContext &attrCtx, bool useEnumOptimization, const vespalib::string &name) +{ + const IAttributeVector *attribute = useEnumOptimization ? attrCtx.getAttributeStableEnum(name) : attrCtx.getAttribute(name); + if (attribute == nullptr) { + throw std::runtime_error(vespalib::make_string("Failed locating attribute vector '%s'", name.c_str())); + } + return attribute; +} + +IAttributeVector::largeint_t getUndefinedValue(BasicType::Type basicType) +{ + switch (basicType) { + case BasicType::INT8: + return getUndefined<int8_t>(); + case BasicType::INT16: + return getUndefined<int16_t>(); + case BasicType::INT32: + return getUndefined<int32_t>(); + case BasicType::INT64: + return getUndefined<int64_t>(); + break; + default: + return 0; + } +} + +} + +AttributeKeyedNode::AttributeKeyedNode() + : AttributeNode(), + _keyAttributeName(), + _valueAttributeName(), + _key(), + _keySourceAttributeName(), + _keyAttribute(nullptr), + _keySourceAttribute(nullptr) +{ +} + +AttributeKeyedNode::AttributeKeyedNode(const AttributeKeyedNode &) = default; + +AttributeKeyedNode::AttributeKeyedNode(vespalib::stringref name) + : AttributeNode(name), + _keyAttributeName(), + _valueAttributeName(), + _key(), + _keySourceAttributeName(), + _keyAttribute(nullptr), + _keySourceAttribute(nullptr) +{ + setupAttributeNames(); +} + +AttributeKeyedNode::~AttributeKeyedNode() = default; + +AttributeKeyedNode & +AttributeKeyedNode::operator=(const AttributeKeyedNode &rhs) = default; + +void +AttributeKeyedNode::setupAttributeNames() +{ + vespalib::asciistream keyName; + vespalib::asciistream valueName; + auto leftBracePos = _attributeName.find('{'); + auto baseName = _attributeName.substr(0, leftBracePos); + auto rightBracePos = _attributeName.rfind('}'); + keyName << baseName << ".key"; + valueName << baseName << ".value" << _attributeName.substr(rightBracePos + 1); + _keyAttributeName = keyName.str(); + _valueAttributeName = valueName.str(); + if (rightBracePos != vespalib::string::npos && rightBracePos > leftBracePos) { + if (_attributeName[leftBracePos + 1] == '"' && _attributeName[rightBracePos - 1] == '"') { + _key = _attributeName.substr(leftBracePos + 2, rightBracePos - leftBracePos - 3); + } else if (_attributeName.substr(leftBracePos + 1, indirectKeyMarker.size()) == indirectKeyMarker && _attributeName[rightBracePos - 1] == ')') { + auto startPos = leftBracePos + 1 + indirectKeyMarker.size(); + _keySourceAttributeName = _attributeName.substr(startPos, rightBracePos - 1 - startPos); + } + } +} + +template <typename ResultNodeType> +void +AttributeKeyedNode::prepareIntValues(std::unique_ptr<KeyHandler> keyHandler, const IAttributeVector &attribute, IAttributeVector::largeint_t undefinedValue) +{ + auto resultNode = std::make_unique<ResultNodeType>(); + _handler = std::make_unique<IntegerValueHandler<ResultNodeType>>(std::move(keyHandler), attribute, *resultNode, undefinedValue); + setResultType(std::move(resultNode)); +} + +std::unique_ptr<AttributeKeyedNode::KeyHandler> +AttributeKeyedNode::makeKeyHandlerHelper() +{ + const IAttributeVector &attribute = *_keyAttribute; + if (_keySourceAttribute != nullptr) { + const IAttributeVector &keySourceAttribute = *_keySourceAttribute; + if (attribute.isIntegerType() && keySourceAttribute.isIntegerType()) { + return std::make_unique<IndirectIntegerKeyHandler>(attribute, keySourceAttribute); + } else if (attribute.isFloatingPointType() && keySourceAttribute.isFloatingPointType()) { + return std::make_unique<IndirectFloatKeyHandler>(attribute, keySourceAttribute); + } else if (attribute.isStringType() && keySourceAttribute.isStringType()) { + return std::make_unique<IndirectStringKeyHandler>(attribute, keySourceAttribute); + } else { + return std::make_unique<BadKeyHandler>(attribute); + } + } + if (attribute.hasEnum() && _useEnumOptimization) { + return std::make_unique<EnumKeyHandler>(attribute, _key); + } else if (attribute.isIntegerType()) { + return std::make_unique<IntegerKeyHandler>(attribute, _key); + } else if (attribute.isFloatingPointType()) { + return std::make_unique<FloatKeyHandler>(attribute, _key); + } else if (attribute.isStringType()) { + return std::make_unique<StringKeyHandler>(attribute, _key); + } else { + return std::make_unique<BadKeyHandler>(attribute); + } +} + +std::unique_ptr<AttributeKeyedNode::KeyHandler> +AttributeKeyedNode::makeKeyHandler() +{ + try { + return makeKeyHandlerHelper(); + } catch (const vespalib::IllegalArgumentException &) { + return std::make_unique<BadKeyHandler>(*_keyAttribute); + } +} + +void +AttributeKeyedNode::onPrepare(bool preserveAccurateTypes) +{ + auto keyHandler = makeKeyHandler(); + const IAttributeVector * attribute = _scratchResult->getAttribute(); + if (attribute != nullptr) { + BasicType::Type basicType = attribute->getBasicType(); + if (attribute->isIntegerType()) { + IAttributeVector::largeint_t undefinedValue = getUndefinedValue(basicType); + if (preserveAccurateTypes) { + switch (basicType) { + case BasicType::INT8: + prepareIntValues<Int8ResultNode>(std::move(keyHandler), *attribute, undefinedValue); + break; + case BasicType::INT16: + prepareIntValues<Int16ResultNode>(std::move(keyHandler), *attribute, undefinedValue); + break; + case BasicType::INT32: + prepareIntValues<Int32ResultNode>(std::move(keyHandler), *attribute, undefinedValue); + break; + case BasicType::INT64: + prepareIntValues<Int64ResultNode>(std::move(keyHandler), *attribute, undefinedValue); + break; + default: + throw std::runtime_error("This is no valid integer attribute " + attribute->getName()); + break; + } + } else { + prepareIntValues<Int64ResultNode>(std::move(keyHandler), *attribute, undefinedValue); + } + } else if (attribute->isFloatingPointType()) { + auto resultNode = std::make_unique<FloatResultNode>(); + _handler = std::make_unique<FloatValueHandler>(std::move(keyHandler), *attribute, *resultNode, getUndefined<double>()); + setResultType(std::move(resultNode)); + } else if (attribute->isStringType()) { + if (_useEnumOptimization) { + auto resultNode = std::make_unique<EnumResultNode>(); + _handler = std::make_unique<EnumValueHandler>(std::move(keyHandler), *attribute, *resultNode, EnumHandle()); + setResultType(std::move(resultNode)); + } else { + auto resultNode = std::make_unique<StringResultNode>(); + _handler = std::make_unique<StringValueHandler>(std::move(keyHandler), *attribute, *resultNode, ""); + setResultType(std::move(resultNode)); + } + } else { + throw std::runtime_error(vespalib::make_string("Can not deduce correct resultclass for attribute vector '%s'", + attribute->getName().c_str())); + } + } +} + +void +AttributeKeyedNode::cleanup() +{ + _keyAttribute = nullptr; + _keySourceAttribute = nullptr; + AttributeNode::cleanup(); +} + +void +AttributeKeyedNode::wireAttributes(const search::attribute::IAttributeContext &attrCtx) +{ + auto valueAttribute = findAttribute(attrCtx, _useEnumOptimization, _valueAttributeName); + _hasMultiValue = false; + _scratchResult = std::make_unique<AttributeResult>(valueAttribute, 0); + _keyAttribute = findAttribute(attrCtx, _useEnumOptimization, _keyAttributeName); + if (!_keySourceAttributeName.empty()) { + _keySourceAttribute = findAttribute(attrCtx, false, _keySourceAttributeName); + } +} + +void +AttributeKeyedNode::visitMembers(vespalib::ObjectVisitor &visitor) const +{ + AttributeNode::visitMembers(visitor); + visit(visitor, "keyAttributeName", _keyAttributeName); + visit(visitor, "keySourceAttributeName", _keySourceAttributeName); + visit(visitor, "valueAttributeName", _valueAttributeName); + visit(visitor, "key", _key); +} + +} diff --git a/searchlib/src/vespa/searchlib/expression/attribute_keyed_node.h b/searchlib/src/vespa/searchlib/expression/attribute_keyed_node.h new file mode 100644 index 00000000000..e2cf8943aae --- /dev/null +++ b/searchlib/src/vespa/searchlib/expression/attribute_keyed_node.h @@ -0,0 +1,45 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "attributenode.h" + +namespace search::expression { + +/** + * Extract map value from attribute for the map key specified in the + * grouping expression. + */ +class AttributeKeyedNode : public AttributeNode +{ +public: + using IAttributeVector = search::attribute::IAttributeVector; + class KeyHandler; +private: + vespalib::string _keyAttributeName; + vespalib::string _valueAttributeName; + vespalib::string _key; + vespalib::string _keySourceAttributeName; + const IAttributeVector *_keyAttribute; + const IAttributeVector *_keySourceAttribute; + + void setupAttributeNames(); + template <typename ResultNodeType> + void prepareIntValues(std::unique_ptr<KeyHandler> keyHandler, const IAttributeVector &attribute, IAttributeVector::largeint_t undefinedValue); + std::unique_ptr<KeyHandler> makeKeyHandlerHelper(); + std::unique_ptr<KeyHandler> makeKeyHandler(); + void cleanup() override; + void wireAttributes(const search::attribute::IAttributeContext & attrCtx) override; + void onPrepare(bool preserveAccurateTypes) override; +public: + AttributeKeyedNode(); + AttributeKeyedNode(vespalib::stringref name); + AttributeKeyedNode(const AttributeKeyedNode &); + AttributeKeyedNode(AttributeKeyedNode &&) = delete; + ~AttributeKeyedNode() override; + AttributeKeyedNode &operator=(const AttributeKeyedNode &rhs); + AttributeKeyedNode &operator=(AttributeKeyedNode &&rhs) = delete; + void visitMembers(vespalib::ObjectVisitor &visitor) const override; + bool isKeyed() const override { return true; } +}; + +} diff --git a/searchlib/src/vespa/searchlib/expression/attributenode.h b/searchlib/src/vespa/searchlib/expression/attributenode.h index 3cbccd32e60..e12b5490955 100644 --- a/searchlib/src/vespa/searchlib/expression/attributenode.h +++ b/searchlib/src/vespa/searchlib/expression/attributenode.h @@ -55,7 +55,8 @@ public: void useEnumOptimization(bool use=true) { _useEnumOptimization = use; } bool hasMultiValue() const { return _hasMultiValue; } -protected: + virtual bool isKeyed() const { return false; } +public: class Handler { public: @@ -68,7 +69,7 @@ private: class StringHandler; class EnumHandler; protected: - void cleanup(); + virtual void cleanup(); void wireAttributes(const search::attribute::IAttributeContext & attrCtx) override; void onPrepare(bool preserveAccurateTypes) override; bool onExecute() const override; diff --git a/searchlib/src/vespa/searchlib/fef/indexproperties.cpp b/searchlib/src/vespa/searchlib/fef/indexproperties.cpp index b05d5fb4e54..5cd6c479d24 100644 --- a/searchlib/src/vespa/searchlib/fef/indexproperties.cpp +++ b/searchlib/src/vespa/searchlib/fef/indexproperties.cpp @@ -399,7 +399,13 @@ const uint32_t HeapSize::DEFAULT_VALUE(100); uint32_t HeapSize::lookup(const Properties &props) { - return lookupUint32(props, NAME, DEFAULT_VALUE); + return lookup(props, DEFAULT_VALUE); +} + +uint32_t +HeapSize::lookup(const Properties &props, uint32_t defaultValue) +{ + return lookupUint32(props, NAME, defaultValue); } const vespalib::string ArraySize::NAME("vespa.hitcollector.arraysize"); diff --git a/searchlib/src/vespa/searchlib/fef/indexproperties.h b/searchlib/src/vespa/searchlib/fef/indexproperties.h index 68bed502121..8b78e347a90 100644 --- a/searchlib/src/vespa/searchlib/fef/indexproperties.h +++ b/searchlib/src/vespa/searchlib/fef/indexproperties.h @@ -320,6 +320,7 @@ namespace hitcollector { static const vespalib::string NAME; static const uint32_t DEFAULT_VALUE; static uint32_t lookup(const Properties &props); + static uint32_t lookup(const Properties &props, uint32_t defaultValue); }; /** diff --git a/searchlib/src/vespa/searchlib/transactionlog/domain.cpp b/searchlib/src/vespa/searchlib/transactionlog/domain.cpp index 1caff132779..fc9518ccf1b 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/domain.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/domain.cpp @@ -60,13 +60,13 @@ Domain::Domain(const string &domainName, const string & baseDir, Executor & comm } _sessionExecutor.sync(); if (_parts.empty() || _parts.crbegin()->second->isClosed()) { - _parts[lastPart].reset(new DomainPart(_name, dir(), lastPart, _defaultCrcType, _fileHeaderContext, false)); + _parts[lastPart] = std::make_shared<DomainPart>(_name, dir(), lastPart, _defaultCrcType, _fileHeaderContext, false); vespalib::File::sync(dir()); } } void Domain::addPart(int64_t partId, bool isLastPart) { - DomainPart::SP dp(new DomainPart(_name, dir(), partId, _defaultCrcType, _fileHeaderContext, isLastPart)); + auto dp = std::make_shared<DomainPart>(_name, dir(), partId, _defaultCrcType, _fileHeaderContext, isLastPart); if (dp->size() == 0) { // Only last domain part is allowed to be truncated down to // empty size. @@ -199,7 +199,7 @@ Domain::triggerSyncNow() if (!_pendingSync) { _pendingSync = true; DomainPart::SP dp(_parts.rbegin()->second); - _commitExecutor.execute(Sync::UP(new Sync(_syncMonitor, dp, _pendingSync))); + _commitExecutor.execute(std::make_unique<Sync>(_syncMonitor, dp, _pendingSync)); } } @@ -290,7 +290,7 @@ void Domain::commit(const Packet & packet) triggerSyncNow(); waitPendingSync(_syncMonitor, _pendingSync); dp->close(); - dp.reset(new DomainPart(_name, dir(), entry.serial(), _defaultCrcType, _fileHeaderContext, false)); + dp = std::make_shared<DomainPart>(_name, dir(), entry.serial(), _defaultCrcType, _fileHeaderContext, false); { LockGuard guard(_lock); _parts[entry.serial()] = dp; @@ -322,15 +322,16 @@ bool Domain::erase(SerialNum to) } int Domain::visit(const Domain::SP & domain, SerialNum from, SerialNum to, - FRT_Supervisor & supervisor, FNET_Connection *conn) + std::unique_ptr<Session::Destination> dest) { assert(this == domain.get()); cleanSessions(); SerialNumRange range(from, to); - Session * session = new Session(_sessionId++, range, domain, supervisor, conn); + auto session = std::make_shared<Session>(_sessionId++, range, domain, std::move(dest)); + int id = session->id(); LockGuard guard(_sessionLock); - _sessions[session->id()] = Session::SP(session); - return session->id(); + _sessions[id] = std::move(session); + return id; } int Domain::startSession(int sessionId) diff --git a/searchlib/src/vespa/searchlib/transactionlog/domain.h b/searchlib/src/vespa/searchlib/transactionlog/domain.h index c1ff9157a6f..c0ee484926c 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/domain.h +++ b/searchlib/src/vespa/searchlib/transactionlog/domain.h @@ -51,7 +51,7 @@ public: bool erase(SerialNum to); void commit(const Packet & packet); - int visit(const Domain::SP & self, SerialNum from, SerialNum to, FRT_Supervisor & supervisor, FNET_Connection *conn); + int visit(const Domain::SP & self, SerialNum from, SerialNum to, std::unique_ptr<Session::Destination> dest); SerialNum begin() const; SerialNum end() const; diff --git a/searchlib/src/vespa/searchlib/transactionlog/session.cpp b/searchlib/src/vespa/searchlib/transactionlog/session.cpp index cbcbc68fdff..e703c32484f 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/session.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/session.cpp @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "session.h" #include "domain.h" -#include <vespa/fnet/frt/supervisor.h> #include <vespa/fastlib/io/bufferedfile.h> #include <vespa/log/log.h> @@ -11,14 +10,10 @@ using vespalib::LockGuard; namespace search::transactionlog { -namespace { - const double NEVER(-1.0); -} - vespalib::Executor::Task::UP Session::createTask(const Session::SP & session) { - return Task::UP(new VisitTask(session)); + return std::make_unique<VisitTask>(session); } Session::VisitTask::VisitTask(const Session::SP & session) @@ -86,7 +81,7 @@ Session::visitOnly() } bool Session::finished() const { - return _finished || (_connection->GetState() != FNET_Connection::FNET_CONNECTED); + return _finished || ! _destination->connected(); } void @@ -99,95 +94,31 @@ Session::finalize() _finished = true; } -int32_t -Session::rpc(FRT_RPCRequest * req) -{ - int32_t retval(-7); - LOG(debug, "rpc %s starting.", req->GetMethodName()); - FRT_Supervisor::InvokeSync(_supervisor.GetTransport(), _connection, req, NEVER); - if (req->GetErrorCode() == FRTE_NO_ERROR) { - retval = (req->GetReturn()->GetValue(0)._intval32); - LOG(debug, "rpc %s = %d\n", req->GetMethodName(), retval); - } else if (req->GetErrorCode() == FRTE_RPC_TIMEOUT) { - LOG(warning, "rpc %s timed out. Will allow to continue: error(%d): %s\n", req->GetMethodName(), req->GetErrorCode(), req->GetErrorMessage()); - retval = -req->GetErrorCode(); - } else { - if (req->GetErrorCode() != FRTE_RPC_CONNECTION) { - LOG(warning, "rpc %s: error(%d): %s\n", req->GetMethodName(), req->GetErrorCode(), req->GetErrorMessage()); - } - retval = -req->GetErrorCode(); - _ok = false; - } - return retval; -} - -void -Session::RequestDone(FRT_RPCRequest * req) -{ - _ok = (req->GetErrorCode() == FRTE_NO_ERROR); - if (req->GetErrorCode() != FRTE_NO_ERROR) { - LOG(warning, "rpcAsync failed %s: error(%d): %s\n", req->GetMethodName(), req->GetErrorCode(), req->GetErrorMessage()); - } else { - int32_t retval = req->GetReturn()->GetValue(0)._intval32; - if (retval != RPC::OK) { - LOG(error, "Return value != OK in RequestDone for method '%s'", req->GetMethodName()); - } - } - req->SubRef(); -} - Session::Session(int sId, const SerialNumRange & r, const Domain::SP & d, - FRT_Supervisor & supervisor, FNET_Connection *conn) : - _supervisor(supervisor), - _connection(conn), + std::unique_ptr<Destination> destination) : + _destination(std::move(destination)), _domain(d), _range(r), _id(sId), - _ok(true), _visitRunning(false), _inSync(false), _finished(false), _startTime() { - _connection->AddRef(); } -Session::~Session() -{ - _connection->SubRef(); -} +Session::~Session() = default; bool Session::send(const Packet & packet) { - FRT_RPCRequest *req = _supervisor.AllocRPCRequest(); - req->SetMethodName("visitCallback"); - req->GetParams()->AddString(_domain->name().c_str()); - req->GetParams()->AddInt32(id()); - req->GetParams()->AddData(packet.getHandle().c_str(), packet.getHandle().size()); - return send(req); -} - -bool -Session::send(FRT_RPCRequest * req) -{ - int32_t retval = rpc(req); - if ( ! ((retval == RPC::OK) || (retval == FRTE_RPC_CONNECTION)) ) { - LOG(error, "Return value != OK(%d) in send for method 'visitCallback'.", retval); - } - req->SubRef(); - - return (retval == RPC::OK); + return _destination->send(_id, _domain->name(), packet); } bool Session::sendDone() { - FRT_RPCRequest *req = _supervisor.AllocRPCRequest(); - req->SetMethodName("eofCallback"); - req->GetParams()->AddString(_domain->name().c_str()); - req->GetParams()->AddInt32(id()); - bool retval(send(req)); + bool retval = _destination->sendDone(_id, _domain->name()); _inSync = true; return retval; } diff --git a/searchlib/src/vespa/searchlib/transactionlog/session.h b/searchlib/src/vespa/searchlib/transactionlog/session.h index 29038ec5290..bf35d83c000 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/session.h +++ b/searchlib/src/vespa/searchlib/transactionlog/session.h @@ -4,9 +4,9 @@ #include "common.h" #include <vespa/vespalib/util/executor.h> #include <vespa/vespalib/util/sync.h> -#include <vespa/fnet/frt/invoker.h> #include <chrono> #include <deque> +#include <atomic> class FastOS_FileInterface; @@ -16,22 +16,29 @@ class Domain; class DomainPart; using DomainSP = std::shared_ptr<Domain>; -class Session : public FRT_IRequestWait +class Session { private: using Task = vespalib::Executor::Task; using time_point = std::chrono::time_point<std::chrono::steady_clock>; public: + class Destination { + public: + virtual ~Destination() {} + virtual bool send(int32_t id, const vespalib::string & domain, const Packet & packet) = 0; + virtual bool sendDone(int32_t id, const vespalib::string & domain) = 0; + virtual bool connected() const = 0; + virtual bool ok() const = 0; + }; typedef std::shared_ptr<Session> SP; Session(const Session &) = delete; Session & operator = (const Session &) = delete; - Session(int sId, const SerialNumRange & r, const DomainSP & d, FRT_Supervisor & supervisor, FNET_Connection *conn); + Session(int sId, const SerialNumRange & r, const DomainSP & d, std::unique_ptr<Destination> destination); ~Session(); const SerialNumRange & range() const { return _range; } int id() const { return _id; } bool inSync() const { return _inSync; } - bool ok() const { return _ok; } bool finished() const; static Task::UP createTask(const Session::SP & session); void setStartTime(time_point startTime) { _startTime = startTime; } @@ -47,8 +54,7 @@ private: Session::SP _session; }; - bool send(FRT_RPCRequest * req); - void RequestDone(FRT_RPCRequest *req) override; + bool ok() const { return _destination->ok(); } bool send(const Packet & packet); bool sendDone(); void visit(); @@ -56,17 +62,14 @@ private: void startVisit(); void finalize(); bool visit(FastOS_FileInterface & file, DomainPart & dp) __attribute__((noinline)); - int32_t rpc(FRT_RPCRequest * req); - FRT_Supervisor & _supervisor; - FNET_Connection * _connection; - DomainSP _domain; - SerialNumRange _range; - int _id; - bool _ok; - std::atomic<bool> _visitRunning; - std::atomic<bool> _inSync; - std::atomic<bool> _finished; - time_point _startTime; + std::unique_ptr<Destination> _destination; + DomainSP _domain; + SerialNumRange _range; + int _id; + std::atomic<bool> _visitRunning; + std::atomic<bool> _inSync; + std::atomic<bool> _finished; + time_point _startTime; }; } diff --git a/searchlib/src/vespa/searchlib/transactionlog/translogclient.cpp b/searchlib/src/vespa/searchlib/transactionlog/translogclient.cpp index aa2b558ea0c..767c8b45e10 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/translogclient.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/translogclient.cpp @@ -11,15 +11,40 @@ LOG_SETUP(".translogclient"); using namespace std::chrono_literals; +VESPA_THREAD_STACK_TAG(translogclient_rpc_callback) + namespace search::transactionlog { namespace { const double NEVER(-1.0); } +namespace { + +struct RpcTask : public vespalib::Executor::Task { + FRT_RPCRequest *req; + std::function<void(FRT_RPCRequest *req)> fun; + RpcTask(FRT_RPCRequest *req_in, std::function<void(FRT_RPCRequest *req)> &&fun_in) + : req(req_in), fun(std::move(fun_in)) {} + void run() override { + fun(req); + req->Return(); + req = nullptr; + } + ~RpcTask() { + if (req != nullptr) { + req->SetError(FRTE_RPC_METHOD_FAILED, "client has been shut down"); + req->Return(); + } + } +}; + +} + using vespalib::LockGuard; TransLogClient::TransLogClient(const vespalib::string & rpcTarget) : + _executor(1, 128 * 1024, translogclient_rpc_callback), _rpcTarget(rpcTarget), _sessions(), _supervisor(std::make_unique<FRT_Supervisor>()), @@ -33,6 +58,7 @@ TransLogClient::TransLogClient(const vespalib::string & rpcTarget) : TransLogClient::~TransLogClient() { disconnect(); + _executor.shutdown().sync(); _supervisor->ShutDown(true); } @@ -139,7 +165,7 @@ void TransLogClient::exportRPC(FRT_Supervisor & supervisor) FRT_ReflectionBuilder rb( & supervisor); //-- Visit Callbacks ----------------------------------------------------------- - rb.DefineMethod("visitCallback", "six", "i", false, FRT_METHOD(TransLogClient::visitCallbackRPC), this); + rb.DefineMethod("visitCallback", "six", "i", FRT_METHOD(TransLogClient::visitCallbackRPC_hook), this); rb.MethodDesc("Will return data asked from a subscriber/visitor."); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("session", "Session handle."); @@ -147,14 +173,15 @@ void TransLogClient::exportRPC(FRT_Supervisor & supervisor) rb.ReturnDesc("result", "A resultcode(int) of the operation. Non zero number indicates error."); //-- Visit Callbacks ----------------------------------------------------------- - rb.DefineMethod("eofCallback", "si", "i", false, FRT_METHOD(TransLogClient::eofCallbackRPC), this); + rb.DefineMethod("eofCallback", "si", "i", FRT_METHOD(TransLogClient::eofCallbackRPC_hook), this); rb.MethodDesc("Will tell you that you are done with the visitor."); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("session", "Session handle."); rb.ReturnDesc("result", "A resultcode(int) of the operation. Non zero number indicates error."); } -void TransLogClient::visitCallbackRPC(FRT_RPCRequest *req) + +void TransLogClient::do_visitCallbackRPC(FRT_RPCRequest *req) { uint32_t retval(uint32_t(-1)); FRT_Values & params = *req->GetParams(); @@ -171,7 +198,7 @@ void TransLogClient::visitCallbackRPC(FRT_RPCRequest *req) LOG(debug, "visitCallback(%s, %d)=%d done", domainName, sessionId, retval); } -void TransLogClient::eofCallbackRPC(FRT_RPCRequest *req) +void TransLogClient::do_eofCallbackRPC(FRT_RPCRequest *req) { uint32_t retval(uint32_t(-1)); FRT_Values & params = *req->GetParams(); @@ -188,6 +215,16 @@ void TransLogClient::eofCallbackRPC(FRT_RPCRequest *req) LOG(debug, "eofCallback(%s, %d)=%d done", domainName, sessionId, retval); } +void TransLogClient::visitCallbackRPC_hook(FRT_RPCRequest *req) +{ + _executor.execute(std::make_unique<RpcTask>(req->Detach(), [this](FRT_RPCRequest *x){ do_visitCallbackRPC(x); })); +} + +void TransLogClient::eofCallbackRPC_hook(FRT_RPCRequest *req) +{ + _executor.execute(std::make_unique<RpcTask>(req->Detach(), [this](FRT_RPCRequest *x){ do_eofCallbackRPC(x); })); +} + TransLogClient::Session::Session(const vespalib::string & domain, TransLogClient & tlc) : _tlc(tlc), diff --git a/searchlib/src/vespa/searchlib/transactionlog/translogclient.h b/searchlib/src/vespa/searchlib/transactionlog/translogclient.h index 87901890673..267d6e3b0ed 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/translogclient.h +++ b/searchlib/src/vespa/searchlib/transactionlog/translogclient.h @@ -5,6 +5,7 @@ #include <vespa/document/util/bytebuffer.h> #include <vespa/vespalib/util/sync.h> #include <vespa/vespalib/util/buffer.h> +#include <vespa/vespalib/util/threadstackexecutor.h> #include <vespa/fnet/frt/invokable.h> #include <map> #include <vector> @@ -96,8 +97,10 @@ public: const vespalib::string &getRPCTarget() const { return _rpcTarget; } private: void exportRPC(FRT_Supervisor & supervisor); - void visitCallbackRPC(FRT_RPCRequest *req); - void eofCallbackRPC(FRT_RPCRequest *req); + void do_visitCallbackRPC(FRT_RPCRequest *req); + void do_eofCallbackRPC(FRT_RPCRequest *req); + void visitCallbackRPC_hook(FRT_RPCRequest *req); + void eofCallbackRPC_hook(FRT_RPCRequest *req); int32_t rpc(FRT_RPCRequest * req); Session * findSession(const vespalib::string & domain, int sessionId); @@ -114,6 +117,7 @@ private: typedef std::map< SessionKey, Session * > SessionMap; + vespalib::ThreadStackExecutor _executor; vespalib::string _rpcTarget; SessionMap _sessions; //Brute force lock for subscriptions. For multithread safety. diff --git a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp index 65bb682a389..4b3e7bddb07 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp @@ -4,6 +4,8 @@ #include <vespa/vespalib/io/fileutil.h> #include <vespa/vespalib/util/exceptions.h> #include <vespa/fnet/frt/supervisor.h> +#include <vespa/fnet/frt/rpcrequest.h> +#include <vespa/fnet/task.h> #include <fstream> #include <vespa/log/log.h> @@ -26,21 +28,16 @@ class SyncHandler : public FNET_Task SerialNum _syncTo; public: - SyncHandler(FRT_Supervisor *supervisor, - FRT_RPCRequest *req,const Domain::SP &domain, - const TransLogServer::Session::SP &session, - SerialNum syncTo); + SyncHandler(FRT_Supervisor *supervisor, FRT_RPCRequest *req,const Domain::SP &domain, + const TransLogServer::Session::SP &session, SerialNum syncTo); ~SyncHandler(); void PerformTask() override; }; -SyncHandler::SyncHandler(FRT_Supervisor *supervisor, - FRT_RPCRequest *req, - const Domain::SP &domain, - const TransLogServer::Session::SP &session, - SerialNum syncTo) +SyncHandler::SyncHandler(FRT_Supervisor *supervisor, FRT_RPCRequest *req, const Domain::SP &domain, + const TransLogServer::Session::SP &session, SerialNum syncTo) : FNET_Task(supervisor->GetScheduler()), _req(*req), _domain(domain), @@ -50,9 +47,7 @@ SyncHandler::SyncHandler(FRT_Supervisor *supervisor, } -SyncHandler::~SyncHandler() -{ -} +SyncHandler::~SyncHandler() = default; void @@ -154,14 +149,16 @@ TransLogServer::~TransLogServer() _supervisor->ShutDown(true); } -bool TransLogServer::onStop() +bool +TransLogServer::onStop() { LOG(info, "Stopping TLS"); _reqQ.push(NULL); return true; } -void TransLogServer::run() +void +TransLogServer::run() { FRT_RPCRequest *req(NULL); bool hasPacket(false); @@ -236,7 +233,8 @@ TransLogServer::findDomain(stringref domainName) return domain; } -void TransLogServer::exportRPC(FRT_Supervisor & supervisor) +void +TransLogServer::exportRPC(FRT_Supervisor & supervisor) { _supervisor->SetSessionInitHook(FRT_METHOD(TransLogServer::initSession), this); _supervisor->SetSessionFiniHook(FRT_METHOD(TransLogServer::finiSession), this); @@ -244,32 +242,32 @@ void TransLogServer::exportRPC(FRT_Supervisor & supervisor) FRT_ReflectionBuilder rb( & supervisor); //-- Create Domain ----------------------------------------------------------- - rb.DefineMethod("createDomain", "s", "i", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("createDomain", "s", "i", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("Create a new domain."); rb.ParamDesc("name", "The name of the domain."); rb.ReturnDesc("handle", "A handle(int) to the domain. Negative number indicates error."); //-- Delete Domain ----------------------------------------------------------- - rb.DefineMethod("deleteDomain", "s", "is", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("deleteDomain", "s", "is", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("Create a new domain."); rb.ParamDesc("name", "The name of the domain."); rb.ReturnDesc("retval", "0 on success. Negative number indicates error."); rb.ReturnDesc("errormsg", "Message describing the error, if any."); //-- Open Domain ----------------------------------------------------------- - rb.DefineMethod("openDomain", "s", "i", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("openDomain", "s", "i", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("Open an existing domain."); rb.ParamDesc("name", "The name of the domain."); rb.ReturnDesc("handle", "A handle(int) to the domain. Negative number indicates error."); //-- List Domains ----------------------------------------------------------- - rb.DefineMethod("listDomains", "", "is", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("listDomains", "", "is", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("Will return a list of all the domains."); rb.ReturnDesc("result", "A resultcode(int) of the operation. Negative number indicates error."); rb.ReturnDesc("domains", "List of all the domains in a newline separated string"); //-- Domain Status ----------------------------------------------------------- - rb.DefineMethod("domainStatus", "s", "illl", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("domainStatus", "s", "illl", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("This will return key status information about the domain."); rb.ParamDesc("name", "The name of the domain."); rb.ReturnDesc("result", "A resultcode(int) of the operation. Negative number indicates error."); @@ -278,7 +276,7 @@ void TransLogServer::exportRPC(FRT_Supervisor & supervisor) rb.ReturnDesc("size", "Number of elements in the log."); //-- Domain Commit ----------------------------------------------------------- - rb.DefineMethod("domainCommit", "sx", "is", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("domainCommit", "sx", "is", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("Will commit the data to the log."); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("packet", "The data to commit to the domain."); @@ -286,14 +284,14 @@ void TransLogServer::exportRPC(FRT_Supervisor & supervisor) rb.ReturnDesc("message", "A textual description of the result code."); //-- Domain Prune ----------------------------------------------------------- - rb.DefineMethod("domainPrune", "sl", "i", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("domainPrune", "sl", "i", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("Will erase all operations prior to the serial number."); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("to", "Will erase all up and including."); rb.ReturnDesc("result", "A resultcode(int) of the operation. Negative number indicates error."); //-- Domain Visit ----------------------------------------------------------- - rb.DefineMethod("domainVisit", "sll", "i", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("domainVisit", "sll", "i", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("This will create a visitor that return all operations in the range."); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("from", "Will return all entries following(not including) <from>."); @@ -301,21 +299,21 @@ void TransLogServer::exportRPC(FRT_Supervisor & supervisor) rb.ReturnDesc("result", "A resultcode(int) of the operation. Negative number indicates error. Positive number is the sessionid"); //-- Domain Session Run ----------------------------------------------------------- - rb.DefineMethod("domainSessionRun", "si", "i", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("domainSessionRun", "si", "i", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("This will start the session thread."); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("sessionid", "The session identifier."); rb.ReturnDesc("result", "A resultcode(int) of the operation. Negative number indicates error."); //-- Domain Session Close ----------------------------------------------------------- - rb.DefineMethod("domainSessionClose", "si", "i", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("domainSessionClose", "si", "i", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("This will close the session."); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("sessionid", "The session identifier."); rb.ReturnDesc("result", "A resultcode(int) of the operation. Negative number indicates error. 1 means busy -> retry. 0 is OK."); //-- Domain Sync -- - rb.DefineMethod("domainSync", "sl", "il", true, FRT_METHOD(TransLogServer::relayToThreadRPC), this); + rb.DefineMethod("domainSync", "sl", "il", FRT_METHOD(TransLogServer::relayToThreadRPC), this); rb.MethodDesc("Sync domain to given entry"); rb.ParamDesc("name", "The name of the domain."); rb.ParamDesc("syncto", "Entry to sync to"); @@ -325,6 +323,8 @@ void TransLogServer::exportRPC(FRT_Supervisor & supervisor) namespace { +constexpr double NEVER(-1.0); + void writeDomainDir(std::lock_guard<std::mutex> &guard, vespalib::string dir, @@ -344,9 +344,77 @@ writeDomainDir(std::lock_guard<std::mutex> &guard, vespalib::File::sync(dir); } +class RPCDestination : public Session::Destination { +public: + RPCDestination(FRT_Supervisor & supervisor, FNET_Connection * connection) + : _supervisor(supervisor), _connection(connection), _ok(true) + { + _connection->AddRef(); + } + ~RPCDestination() override { _connection->SubRef(); } + + bool ok() const override { + return _ok; + } + + bool send(int32_t id, const vespalib::string & domain, const Packet & packet) override { + FRT_RPCRequest *req = _supervisor.AllocRPCRequest(); + req->SetMethodName("visitCallback"); + req->GetParams()->AddString(domain.c_str()); + req->GetParams()->AddInt32(id); + req->GetParams()->AddData(packet.getHandle().c_str(), packet.getHandle().size()); + return send(req); + } + + bool sendDone(int32_t id, const vespalib::string & domain) override { + FRT_RPCRequest *req = _supervisor.AllocRPCRequest(); + req->SetMethodName("eofCallback"); + req->GetParams()->AddString(domain.c_str()); + req->GetParams()->AddInt32(id); + bool retval(send(req)); + return retval; + } + bool connected() const override { + return (_connection->GetState() <= FNET_Connection::FNET_CONNECTED); + } +private: + bool send(FRT_RPCRequest * req) { + int32_t retval = rpc(req); + if ( ! ((retval == RPC::OK) || (retval == FRTE_RPC_CONNECTION)) ) { + LOG(error, "Return value != OK(%d) in send for method 'visitCallback'.", retval); + } + req->SubRef(); + + return (retval == RPC::OK); + } + int32_t rpc(FRT_RPCRequest * req) { + int32_t retval(-7); + LOG(debug, "rpc %s starting.", req->GetMethodName()); + FRT_Supervisor::InvokeSync(_supervisor.GetTransport(), _connection, req, NEVER); + if (req->GetErrorCode() == FRTE_NO_ERROR) { + retval = (req->GetReturn()->GetValue(0)._intval32); + LOG(debug, "rpc %s = %d\n", req->GetMethodName(), retval); + } else if (req->GetErrorCode() == FRTE_RPC_TIMEOUT) { + LOG(warning, "rpc %s timed out. Will allow to continue: error(%d): %s\n", req->GetMethodName(), req->GetErrorCode(), req->GetErrorMessage()); + retval = -req->GetErrorCode(); + } else { + if (req->GetErrorCode() != FRTE_RPC_CONNECTION) { + LOG(warning, "rpc %s: error(%d): %s\n", req->GetMethodName(), req->GetErrorCode(), req->GetErrorMessage()); + } + retval = -req->GetErrorCode(); + _ok = false; + } + return retval; + } + FRT_Supervisor & _supervisor; + FNET_Connection * _connection; + bool _ok; +}; + } -void TransLogServer::createDomain(FRT_RPCRequest *req) +void +TransLogServer::createDomain(FRT_RPCRequest *req) { uint32_t retval(0); FRT_Values & params = *req->GetParams(); @@ -373,7 +441,8 @@ void TransLogServer::createDomain(FRT_RPCRequest *req) ret.AddInt32(retval); } -void TransLogServer::deleteDomain(FRT_RPCRequest *req) +void +TransLogServer::deleteDomain(FRT_RPCRequest *req) { uint32_t retval(0); vespalib::string msg("ok"); @@ -410,7 +479,8 @@ void TransLogServer::deleteDomain(FRT_RPCRequest *req) ret.AddString(msg.c_str()); } -void TransLogServer::openDomain(FRT_RPCRequest *req) +void +TransLogServer::openDomain(FRT_RPCRequest *req) { uint32_t retval(0); FRT_Values & params = *req->GetParams(); @@ -427,7 +497,8 @@ void TransLogServer::openDomain(FRT_RPCRequest *req) ret.AddInt32(retval); } -void TransLogServer::listDomains(FRT_RPCRequest *req) +void +TransLogServer::listDomains(FRT_RPCRequest *req) { FRT_Values & ret = *req->GetReturn(); LOG(debug, "listDomains()"); @@ -442,7 +513,8 @@ void TransLogServer::listDomains(FRT_RPCRequest *req) ret.AddString(domains.c_str()); } -void TransLogServer::domainStatus(FRT_RPCRequest *req) +void +TransLogServer::domainStatus(FRT_RPCRequest *req) { FRT_Values & params = *req->GetParams(); FRT_Values & ret = *req->GetReturn(); @@ -462,7 +534,8 @@ void TransLogServer::domainStatus(FRT_RPCRequest *req) } } -void TransLogServer::commit(const vespalib::string & domainName, const Packet & packet, DoneCallback done) +void +TransLogServer::commit(const vespalib::string & domainName, const Packet & packet, DoneCallback done) { (void) done; Domain::SP domain(findDomain(domainName)); @@ -473,7 +546,8 @@ void TransLogServer::commit(const vespalib::string & domainName, const Packet & } } -void TransLogServer::domainCommit(FRT_RPCRequest *req) +void +TransLogServer::domainCommit(FRT_RPCRequest *req) { FRT_Values & params = *req->GetParams(); FRT_Values & ret = *req->GetReturn(); @@ -496,7 +570,8 @@ void TransLogServer::domainCommit(FRT_RPCRequest *req) } } -void TransLogServer::domainVisit(FRT_RPCRequest *req) +void +TransLogServer::domainVisit(FRT_RPCRequest *req) { uint32_t retval(uint32_t(-1)); FRT_Values & params = *req->GetParams(); @@ -508,12 +583,13 @@ void TransLogServer::domainVisit(FRT_RPCRequest *req) SerialNum from(params[1]._intval64); SerialNum to(params[2]._intval64); LOG(debug, "domainVisit(%s, %" PRIu64 ", %" PRIu64 ")", domainName, from, to); - retval = domain->visit(domain, from, to, *_supervisor, req->GetConnection()); + retval = domain->visit(domain, from, to, std::make_unique<RPCDestination>(*_supervisor, req->GetConnection())); } ret.AddInt32(retval); } -void TransLogServer::domainSessionRun(FRT_RPCRequest *req) +void +TransLogServer::domainSessionRun(FRT_RPCRequest *req) { uint32_t retval(uint32_t(-1)); FRT_Values & params = *req->GetParams(); @@ -529,13 +605,15 @@ void TransLogServer::domainSessionRun(FRT_RPCRequest *req) ret.AddInt32(retval); } -void TransLogServer::relayToThreadRPC(FRT_RPCRequest *req) +void +TransLogServer::relayToThreadRPC(FRT_RPCRequest *req) { req->Detach(); _reqQ.push(req); } -void TransLogServer::domainSessionClose(FRT_RPCRequest *req) +void +TransLogServer::domainSessionClose(FRT_RPCRequest *req) { uint32_t retval(uint32_t(-1)); FRT_Values & params = *req->GetParams(); @@ -552,7 +630,8 @@ void TransLogServer::domainSessionClose(FRT_RPCRequest *req) ret.AddInt32(retval); } -void TransLogServer::domainPrune(FRT_RPCRequest *req) +void +TransLogServer::domainPrune(FRT_RPCRequest *req) { uint32_t retval(uint32_t(-1)); FRT_Values & params = *req->GetParams(); @@ -572,7 +651,6 @@ void TransLogServer::domainPrune(FRT_RPCRequest *req) ret.AddInt32(retval); } - const TransLogServer::Session::SP & TransLogServer::getSession(FRT_RPCRequest *req) { @@ -582,14 +660,12 @@ TransLogServer::getSession(FRT_RPCRequest *req) return *sessionspp; } - void TransLogServer::initSession(FRT_RPCRequest *req) { req->GetConnection()->SetContext(new Session::SP(new Session())); } - void TransLogServer::finiSession(FRT_RPCRequest *req) { @@ -600,14 +676,12 @@ TransLogServer::finiSession(FRT_RPCRequest *req) delete sessionspp; } - void TransLogServer::downSession(FRT_RPCRequest *req) { getSession(req)->setDown(); } - void TransLogServer::domainSync(FRT_RPCRequest *req) { diff --git a/searchlib/src/vespa/searchlib/transactionlog/translogserver.h b/searchlib/src/vespa/searchlib/transactionlog/translogserver.h index 189be8c38d8..8aedfef6d8d 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/translogserver.h +++ b/searchlib/src/vespa/searchlib/transactionlog/translogserver.h @@ -8,6 +8,9 @@ #include <vespa/fnet/frt/invokable.h> #include <mutex> + +class FRT_Supervisor; + namespace search::common { class FileHeaderContext; } namespace search::transactionlog { diff --git a/slobrok/src/tests/mirrorapi/mirrorapi.cpp b/slobrok/src/tests/mirrorapi/mirrorapi.cpp index 0550bf51b0c..f77dfd80986 100644 --- a/slobrok/src/tests/mirrorapi/mirrorapi.cpp +++ b/slobrok/src/tests/mirrorapi/mirrorapi.cpp @@ -41,7 +41,7 @@ Server::Server(std::string name, int port, std::string slobrokSpec) { FRT_ReflectionBuilder rb(&_orb); //--------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", true, + rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", FRT_METHOD(Server::rpc_listNamesServed), this); rb.MethodDesc("Look up a rpcserver"); rb.ReturnDesc("names", "The rpcserver names on this server"); diff --git a/slobrok/src/tests/oldapi/old.cpp b/slobrok/src/tests/oldapi/old.cpp index 77bca6dfe90..42cec186a08 100644 --- a/slobrok/src/tests/oldapi/old.cpp +++ b/slobrok/src/tests/oldapi/old.cpp @@ -39,7 +39,7 @@ Server::Server(std::string name, int port, std::string slobrokSpec) { FRT_ReflectionBuilder rb(&_orb); //--------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", true, + rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", FRT_METHOD(Server::rpc_listNamesServed), this); rb.MethodDesc("Look up a rpcserver"); rb.ReturnDesc("names", "The rpcserver names on this server"); diff --git a/slobrok/src/tests/standalone/standalone.cpp b/slobrok/src/tests/standalone/standalone.cpp index 63f8b1d2c59..136f8125c8b 100644 --- a/slobrok/src/tests/standalone/standalone.cpp +++ b/slobrok/src/tests/standalone/standalone.cpp @@ -26,7 +26,7 @@ Server::Server(std::string name, int port) { FRT_ReflectionBuilder rb(&_orb); //--------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", true, + rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", FRT_METHOD(Server::rpc_listNamesServed), this); rb.MethodDesc("Look up a rpcserver"); rb.ReturnDesc("names", "The rpcserver names on this server"); diff --git a/slobrok/src/tests/startsome/tstdst.cpp b/slobrok/src/tests/startsome/tstdst.cpp index 44b42e1ff83..4723b3819d7 100644 --- a/slobrok/src/tests/startsome/tstdst.cpp +++ b/slobrok/src/tests/startsome/tstdst.cpp @@ -87,12 +87,12 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) FRT_ReflectionBuilder rb(supervisor); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", true, + rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", FRT_METHOD(RPCHooks::rpc_listNamesServed), this); rb.MethodDesc("Look up a rpcserver"); rb.ReturnDesc("names", "The rpcserver names on this server"); //------------------------------------------------------------------------- - rb.DefineMethod("system.stop", "", "", true, + rb.DefineMethod("system.stop", "", "", FRT_METHOD(RPCHooks::rpc_stop), this); rb.MethodDesc("Shut down the application"); //------------------------------------------------------------------------- diff --git a/slobrok/src/vespa/slobrok/sbregister.cpp b/slobrok/src/vespa/slobrok/sbregister.cpp index 8f8e42a39aa..a1346feeece 100644 --- a/slobrok/src/vespa/slobrok/sbregister.cpp +++ b/slobrok/src/vespa/slobrok/sbregister.cpp @@ -277,12 +277,12 @@ RegisterAPI::RPCHooks::RPCHooks(RegisterAPI &owner) { FRT_ReflectionBuilder rb(&_owner._orb); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", true, + rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", FRT_METHOD(RPCHooks::rpc_listNamesServed), this); rb.MethodDesc("List rpcserver names"); rb.ReturnDesc("names", "The rpcserver names this server wants to serve"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.notifyUnregistered", "s", "", true, + rb.DefineMethod("slobrok.callback.notifyUnregistered", "s", "", FRT_METHOD(RPCHooks::rpc_notifyUnregistered), this); rb.MethodDesc("Notify a server about removed registration"); rb.ParamDesc("name", "RpcServer name"); diff --git a/slobrok/src/vespa/slobrok/server/rpchooks.cpp b/slobrok/src/vespa/slobrok/server/rpchooks.cpp index 33cc10937df..82e30a309a1 100644 --- a/slobrok/src/vespa/slobrok/server/rpchooks.cpp +++ b/slobrok/src/vespa/slobrok/server/rpchooks.cpp @@ -81,39 +81,39 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) FRT_ReflectionBuilder rb(supervisor); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.system.resume", "", "", true, + rb.DefineMethod("slobrok.system.resume", "", "", FRT_METHOD(RPCHooks::rpc_resume), this); rb.MethodDesc("Enable something - currently NOP"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.system.suspend", "", "", true, + rb.DefineMethod("slobrok.system.suspend", "", "", FRT_METHOD(RPCHooks::rpc_suspend), this); rb.MethodDesc("Disable something - currently NOP"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.system.version", "", "s", true, + rb.DefineMethod("slobrok.system.version", "", "s", FRT_METHOD(RPCHooks::rpc_version), this); rb.MethodDesc("Get location broker version"); rb.ReturnDesc("version", "version string"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.system.stop", "", "", true, + rb.DefineMethod("slobrok.system.stop", "", "", FRT_METHOD(RPCHooks::rpc_stop), this); rb.MethodDesc("Shut down the location broker application"); //------------------------------------------------------------------------- //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.internal.listManagedRpcServers", "", "SS", true, + rb.DefineMethod("slobrok.internal.listManagedRpcServers", "", "SS", FRT_METHOD(RPCHooks::rpc_listManagedRpcServers), this); rb.MethodDesc("List all rpcservers managed by this location broker"); rb.ReturnDesc("names", "Managed rpcserver names"); rb.ReturnDesc("specs", "The connection specifications (in same order)"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.internal.lookupManaged", "s", "ss", true, + rb.DefineMethod("slobrok.internal.lookupManaged", "s", "ss", FRT_METHOD(RPCHooks::rpc_lookupManaged), this); rb.MethodDesc("Lookup a specific rpcserver managed by this location broker"); rb.ParamDesc("name", "Name of rpc server"); rb.ReturnDesc("name", "Name of rpc server"); rb.ReturnDesc("spec", "The connection specification"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.internal.wantAdd", "sss", "is", true, + rb.DefineMethod("slobrok.internal.wantAdd", "sss", "is", FRT_METHOD(RPCHooks::rpc_wantAdd), this); rb.MethodDesc("remote location broker wants to add a rpcserver"); rb.ParamDesc("slobrok", "Name of remote location broker"); @@ -122,7 +122,7 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) rb.ReturnDesc("denied", "non-zero if request was denied"); rb.ReturnDesc("reason", "reason for denial"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.internal.doAdd", "sss", "is", true, + rb.DefineMethod("slobrok.internal.doAdd", "sss", "is", FRT_METHOD(RPCHooks::rpc_doAdd), this); rb.MethodDesc("add rpcserver managed by remote location broker"); rb.ParamDesc("slobrok", "Name of remote location broker"); @@ -131,7 +131,7 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) rb.ReturnDesc("denied", "non-zero if request was denied"); rb.ReturnDesc("reason", "reason for denial"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.internal.doRemove", "sss", "is", true, + rb.DefineMethod("slobrok.internal.doRemove", "sss", "is", FRT_METHOD(RPCHooks::rpc_doRemove), this); rb.MethodDesc("remove rpcserver managed by remote location broker"); rb.ParamDesc("slobrok", "Name of remote location broker"); @@ -142,31 +142,31 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) //------------------------------------------------------------------------- //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", true, + rb.DefineMethod("slobrok.callback.listNamesServed", "", "S", FRT_METHOD(RPCHooks::rpc_listNamesServed), this); rb.MethodDesc("List rpcservers served"); rb.ReturnDesc("names", "The rpcserver names this server wants to serve"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.callback.notifyUnregistered", "s", "", true, + rb.DefineMethod("slobrok.callback.notifyUnregistered", "s", "", FRT_METHOD(RPCHooks::rpc_notifyUnregistered), this); rb.MethodDesc("Notify a server about removed registration"); rb.ParamDesc("name", "NamedService name"); //------------------------------------------------------------------------- //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.admin.removePeer", "ss", "", true, + rb.DefineMethod("slobrok.admin.removePeer", "ss", "", FRT_METHOD(RPCHooks::rpc_removePeer), this); rb.MethodDesc("stop syncing with other location broker"); rb.ParamDesc("slobrok", "NamedService name of remote location broker"); rb.ParamDesc("spec", "Connection specification of remote location broker"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.admin.addPeer", "ss", "", true, + rb.DefineMethod("slobrok.admin.addPeer", "ss", "", FRT_METHOD(RPCHooks::rpc_addPeer), this); rb.MethodDesc("sync our information with other location broker"); rb.ParamDesc("slobrok", "NamedService name of remote location broker"); rb.ParamDesc("spec", "Connection specification of remote location broker"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.admin.listAllRpcServers", "", "SSS", true, + rb.DefineMethod("slobrok.admin.listAllRpcServers", "", "SSS", FRT_METHOD(RPCHooks::rpc_listAllRpcServers), this); rb.MethodDesc("List all known rpcservers"); rb.ReturnDesc("names", "NamedService names"); @@ -175,13 +175,13 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) //------------------------------------------------------------------------- //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.unregisterRpcServer", "ss", "", true, + rb.DefineMethod("slobrok.unregisterRpcServer", "ss", "", FRT_METHOD(RPCHooks::rpc_unregisterRpcServer), this); rb.MethodDesc("Unregister a rpcserver"); rb.ParamDesc("name", "NamedService name"); rb.ParamDesc("spec", "The connection specification"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.registerRpcServer", "ss", "", true, + rb.DefineMethod("slobrok.registerRpcServer", "ss", "", FRT_METHOD(RPCHooks::rpc_registerRpcServer), this); rb.MethodDesc("Register a rpcserver"); rb.ParamDesc("name", "NamedService name"); @@ -189,7 +189,7 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) //------------------------------------------------------------------------- //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.mirror.fetch", "ii", "SSi", true, + rb.DefineMethod("slobrok.mirror.fetch", "ii", "SSi", FRT_METHOD(RPCHooks::rpc_mirrorFetch), this); rb.MethodDesc("Fetch or update mirror of name to spec map"); rb.ParamDesc("gencnt", "generation already known by client"); @@ -199,7 +199,7 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) rb.ReturnDesc("specs", "Array of connection specifications (same order)"); rb.ReturnDesc("newgen", "Generation count for new version of the map"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.incremental.fetch", "ii", "iSSSi", true, + rb.DefineMethod("slobrok.incremental.fetch", "ii", "iSSSi", FRT_METHOD(RPCHooks::rpc_incrementalFetch), this); rb.MethodDesc("Fetch or update mirror of name to spec map"); rb.ParamDesc("gencnt", "generation already known by client"); @@ -212,7 +212,7 @@ RPCHooks::initRPC(FRT_Supervisor *supervisor) rb.ReturnDesc("specs", "Array of connection specifications (same order)"); rb.ReturnDesc("newgen", "Generation count for new version of the map"); //------------------------------------------------------------------------- - rb.DefineMethod("slobrok.lookupRpcServer", "s", "SS", true, + rb.DefineMethod("slobrok.lookupRpcServer", "s", "SS", FRT_METHOD(RPCHooks::rpc_lookupRpcServer), this); rb.MethodDesc("Look up rpcservers"); rb.ParamDesc("pattern", "The pattern of the rpcservers to lookup.\n" diff --git a/staging_vespalib/src/tests/util/process_memory_stats/process_memory_stats_test.cpp b/staging_vespalib/src/tests/util/process_memory_stats/process_memory_stats_test.cpp index 45e75547fb2..e4274664336 100644 --- a/staging_vespalib/src/tests/util/process_memory_stats/process_memory_stats_test.cpp +++ b/staging_vespalib/src/tests/util/process_memory_stats/process_memory_stats_test.cpp @@ -80,4 +80,12 @@ TEST("grow mapped memory") munmap(mapAddr, mapLen); } +TEST("order samples") +{ + ProcessMemoryStats a(0,0,0,7,0); + ProcessMemoryStats b(0,0,0,8,0); + EXPECT_TRUE(a < b); + EXPECT_FALSE(b < a); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.cpp b/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.cpp index 138e7a25803..f0cbefd443b 100644 --- a/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.cpp +++ b/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.cpp @@ -4,6 +4,7 @@ #include <vespa/vespalib/stllike/asciistream.h> #include <fstream> #include <sstream> +#include <algorithm> #include <vespa/log/log.h> @@ -182,18 +183,20 @@ ProcessMemoryStats::toString() const ProcessMemoryStats ProcessMemoryStats::create(uint64_t sizeEpsilon) { - ProcessMemoryStats prevStats = createStatsFromSmaps(); - const size_t NUM_TRIES = 10; + constexpr size_t NUM_TRIES = 10; + std::vector<ProcessMemoryStats> samples; + samples.reserve(NUM_TRIES); + samples.push_back(createStatsFromSmaps()); for (size_t i = 0; i < NUM_TRIES; ++i) { - ProcessMemoryStats currStats = createStatsFromSmaps(); - if (prevStats.similarTo(currStats, sizeEpsilon)) { - return prevStats; + samples.push_back(createStatsFromSmaps()); + if (samples.back().similarTo(*(samples.rbegin()+1), sizeEpsilon)) { + return samples.back(); } LOG(info, "create(): Memory stats have changed, trying to read smaps file again: i=%zu, prevStats={%s}, currStats={%s}", - i, prevStats.toString().c_str(), currStats.toString().c_str()); - prevStats = currStats; + i, (samples.rbegin()+1)->toString().c_str(), samples.back().toString().c_str()); } - return prevStats; + std::sort(samples.begin(), samples.end()); + return samples[samples.size()/2]; } } diff --git a/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.h b/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.h index fe5062f75cd..3870a2e2907 100644 --- a/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.h +++ b/staging_vespalib/src/vespa/vespalib/util/process_memory_stats.h @@ -2,7 +2,6 @@ #pragma once -#include <cstdint> #include <vespa/vespalib/stllike/string.h> namespace vespalib { @@ -37,6 +36,7 @@ public: uint64_t getMappingsCount() const { return _mappings_count; } bool similarTo(const ProcessMemoryStats &rhs, uint64_t sizeEpsilon) const; vespalib::string toString() const; + bool operator < (const ProcessMemoryStats & rhs) const { return _anonymous_rss < rhs._anonymous_rss; } /** for unit tests only */ ProcessMemoryStats(uint64_t, uint64_t, uint64_t, uint64_t, uint64_t); diff --git a/standalone-container/src/main/java/com/yahoo/container/standalone/CloudConfigInstallVariables.java b/standalone-container/src/main/java/com/yahoo/container/standalone/CloudConfigInstallVariables.java index be91eceee2f..0be4a55275c 100644 --- a/standalone-container/src/main/java/com/yahoo/container/standalone/CloudConfigInstallVariables.java +++ b/standalone-container/src/main/java/com/yahoo/container/standalone/CloudConfigInstallVariables.java @@ -56,7 +56,7 @@ public class CloudConfigInstallVariables implements CloudConfigOptions { @Override public Optional<Integer> zookeeperQuorumPort() { - return getInstallVariable("zookeeper_quoromPort", Integer::parseInt); + return getInstallVariable("zookeeper_quorumPort", Integer::parseInt); } @Override diff --git a/storage/src/vespa/storage/storageserver/fnetlistener.cpp b/storage/src/vespa/storage/storageserver/fnetlistener.cpp index 1a72190b2a6..e31bded772c 100644 --- a/storage/src/vespa/storage/storageserver/fnetlistener.cpp +++ b/storage/src/vespa/storage/storageserver/fnetlistener.cpp @@ -65,7 +65,7 @@ FNetListener::initRPC() { FRT_ReflectionBuilder rb(_orb.get()); - rb.DefineMethod("getnodestate3", "sii", "ss", true, FRT_METHOD(FNetListener::RPC_getNodeState2), this); + rb.DefineMethod("getnodestate3", "sii", "ss", FRT_METHOD(FNetListener::RPC_getNodeState2), this); rb.MethodDesc("Get state of this node"); rb.ParamDesc("nodestate", "Expected state of given node. If correct, the " "request will be queued on target until it changes. To not give " @@ -74,7 +74,7 @@ FNetListener::initRPC() rb.ReturnDesc("nodestate", "State string for this node"); rb.ReturnDesc("hostinfo", "Information about host this node is running on"); //------------------------------------------------------------------------- - rb.DefineMethod("getnodestate2", "si", "s", true, FRT_METHOD(FNetListener::RPC_getNodeState2), this); + rb.DefineMethod("getnodestate2", "si", "s", FRT_METHOD(FNetListener::RPC_getNodeState2), this); rb.MethodDesc("Get state of this node"); rb.ParamDesc("nodestate", "Expected state of given node. If correct, the " "request will be queued on target until it changes. To not give " @@ -82,17 +82,17 @@ FNetListener::initRPC() rb.ParamDesc("timeout", "Timeout of message in milliseconds, set by the state requester"); rb.ReturnDesc("nodestate", "State string for this node"); //------------------------------------------------------------------------- - rb.DefineMethod("setsystemstate2", "s", "", true, FRT_METHOD(FNetListener::RPC_setSystemState2), this); + rb.DefineMethod("setsystemstate2", "s", "", FRT_METHOD(FNetListener::RPC_setSystemState2), this); rb.MethodDesc("Set systemstate on this node"); rb.ParamDesc("systemstate", "New systemstate to set"); //------------------------------------------------------------------------- - rb.DefineMethod("setdistributionstates", "bix", "", true, FRT_METHOD(FNetListener::RPC_setDistributionStates), this); + rb.DefineMethod("setdistributionstates", "bix", "", FRT_METHOD(FNetListener::RPC_setDistributionStates), this); rb.MethodDesc("Set distribution states for cluster and bucket spaces"); rb.ParamDesc("compressionType", "Compression type for payload"); rb.ParamDesc("uncompressedSize", "Uncompressed size for payload"); rb.ParamDesc("payload", "Binary Slime format payload"); //------------------------------------------------------------------------- - rb.DefineMethod("getcurrenttime", "", "lis", true, FRT_METHOD(FNetListener::RPC_getCurrentTime), this); + rb.DefineMethod("getcurrenttime", "", "lis", FRT_METHOD(FNetListener::RPC_getCurrentTime), this); rb.MethodDesc("Get current time on this node"); rb.ReturnDesc("seconds", "Current time in seconds since epoch"); rb.ReturnDesc("nanoseconds", "additional nanoseconds since epoch"); diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateDeserializer.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateDeserializer.java index 5dd6ceb16b4..59f10a78a58 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateDeserializer.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateDeserializer.java @@ -4,7 +4,7 @@ package com.yahoo.vespa.athenz.client.zts.bindings.serializers; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonDeserializer; -import com.yahoo.vespa.athenz.tls.X509CertificateUtils; +import com.yahoo.security.X509CertificateUtils; import java.io.IOException; import java.security.cert.X509Certificate; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateListDeserializer.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateListDeserializer.java index c496031c116..64b23af9295 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateListDeserializer.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zts/bindings/serializers/X509CertificateListDeserializer.java @@ -4,7 +4,7 @@ package com.yahoo.vespa.athenz.client.zts.bindings.serializers; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonDeserializer; -import com.yahoo.vespa.athenz.tls.X509CertificateUtils; +import com.yahoo.security.X509CertificateUtils; import java.io.IOException; import java.security.cert.X509Certificate; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identity/SiaIdentityProvider.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identity/SiaIdentityProvider.java index b06ae089b2a..d8fa910aa73 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identity/SiaIdentityProvider.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identity/SiaIdentityProvider.java @@ -5,8 +5,8 @@ import com.google.inject.Inject; import com.yahoo.component.AbstractComponent; import com.yahoo.log.LogLevel; import com.yahoo.vespa.athenz.api.AthenzService; -import com.yahoo.vespa.athenz.tls.KeyStoreType; -import com.yahoo.vespa.athenz.tls.SslContextBuilder; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.SslContextBuilder; import com.yahoo.vespa.athenz.utils.SiaUtils; import javax.net.ssl.SSLContext; @@ -92,8 +92,8 @@ public class SiaIdentityProvider extends AbstractComponent implements ServiceIde private SSLContext createIdentitySslContext() { return new SslContextBuilder() - .withTrustStore(trustStoreFile, KeyStoreType.JKS) - .withKeyStore(privateKeyFile, certificateFile) + .withTrustStore(trustStoreFile.toPath(), KeyStoreType.JKS) + .withKeyStore(privateKeyFile.toPath(), certificateFile.toPath()) .build(); } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java index 5567831d49d..4a189c872bc 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java @@ -11,10 +11,10 @@ import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper; import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocumentClient; import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; import com.yahoo.vespa.athenz.tls.AthenzIdentityVerifier; -import com.yahoo.vespa.athenz.tls.KeyAlgorithm; -import com.yahoo.vespa.athenz.tls.KeyUtils; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.SslContextBuilder; import com.yahoo.vespa.athenz.tls.Pkcs10Csr; -import com.yahoo.vespa.athenz.tls.SslContextBuilder; import com.yahoo.vespa.athenz.utils.SiaUtils; import com.yahoo.vespa.defaults.Defaults; @@ -31,7 +31,7 @@ import java.time.Clock; import java.time.Duration; import java.util.Optional; -import static com.yahoo.vespa.athenz.tls.KeyStoreType.JKS; +import static com.yahoo.security.KeyStoreType.JKS; import static java.util.Collections.singleton; /** @@ -153,7 +153,7 @@ class AthenzCredentialsService { private SSLContext createIdentitySslContext(PrivateKey privateKey, X509Certificate certificate) { return new SslContextBuilder() .withKeyStore(privateKey, certificate) - .withTrustStore(trustStoreJks, JKS) + .withTrustStore(trustStoreJks.toPath(), JKS) .build(); } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java index 266e2ebcefd..e318ebeb7fd 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java @@ -19,8 +19,8 @@ import com.yahoo.vespa.athenz.client.zts.DefaultZtsClient; import com.yahoo.vespa.athenz.client.zts.ZtsClient; import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider; import com.yahoo.vespa.athenz.identity.SiaIdentityProvider; -import com.yahoo.vespa.athenz.tls.KeyStoreType; -import com.yahoo.vespa.athenz.tls.SslContextBuilder; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.SslContextBuilder; import com.yahoo.vespa.athenz.utils.SiaUtils; import com.yahoo.vespa.defaults.Defaults; @@ -177,7 +177,7 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen X509Certificate roleCertificate = client.getRoleCertificate(role, credentials.getKeyPair(), dnsSuffix); return new SslContextBuilder() .withKeyStore(credentials.getKeyPair().getPrivate(), roleCertificate) - .withTrustStore(getDefaultTrustStoreLocation(), KeyStoreType.JKS) + .withTrustStore(getDefaultTrustStoreLocation().toPath(), KeyStoreType.JKS) .build(); } } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/AthenzX509CertificateUtils.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/AthenzX509CertificateUtils.java index 46aca707be1..33e5552eaf6 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/AthenzX509CertificateUtils.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/AthenzX509CertificateUtils.java @@ -8,7 +8,7 @@ import com.yahoo.vespa.athenz.utils.AthenzIdentities; import java.security.cert.X509Certificate; import java.util.List; -import static com.yahoo.vespa.athenz.tls.SubjectAlternativeName.Type.RFC822_NAME; +import static com.yahoo.security.SubjectAlternativeName.Type.RFC822_NAME; /** * Utility methods for Athenz issued x509 certificates @@ -23,26 +23,26 @@ public class AthenzX509CertificateUtils { public static boolean isAthenzRoleCertificate(X509Certificate certificate) { return isAthenzIssuedCertificate(certificate) && - X509CertificateUtils.getSubjectCommonNames(certificate).get(0).contains(COMMON_NAME_ROLE_DELIMITER); + com.yahoo.security.X509CertificateUtils.getSubjectCommonNames(certificate).get(0).contains(COMMON_NAME_ROLE_DELIMITER); } public static boolean isAthenzIssuedCertificate(X509Certificate certificate) { - return X509CertificateUtils.getIssuerCommonNames(certificate).stream() + return com.yahoo.security.X509CertificateUtils.getIssuerCommonNames(certificate).stream() .anyMatch(cn -> cn.equalsIgnoreCase("Yahoo Athenz CA") || cn.equalsIgnoreCase("Athenz AWS CA")); } public static AthenzIdentity getIdentityFromRoleCertificate(X509Certificate certificate) { - List<SubjectAlternativeName> sans = X509CertificateUtils.getSubjectAlternativeNames(certificate); + List<com.yahoo.security.SubjectAlternativeName> sans = com.yahoo.security.X509CertificateUtils.getSubjectAlternativeNames(certificate); return sans.stream() .filter(san -> san.getType() == RFC822_NAME) - .map(SubjectAlternativeName::getValue) + .map(com.yahoo.security.SubjectAlternativeName::getValue) .map(AthenzX509CertificateUtils::getIdentityFromSanEmail) .findFirst() .orElseThrow(() -> new IllegalArgumentException("Could not find identity in SAN: " + sans)); } public static AthenzRole getRolesFromRoleCertificate(X509Certificate certificate) { - String commonName = X509CertificateUtils.getSubjectCommonNames(certificate).get(0); + String commonName = com.yahoo.security.X509CertificateUtils.getSubjectCommonNames(certificate).get(0); int delimiterIndex = commonName.indexOf(COMMON_NAME_ROLE_DELIMITER); String domain = commonName.substring(0, delimiterIndex); String roleName = commonName.substring(delimiterIndex + COMMON_NAME_ROLE_DELIMITER.length()); diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Extension.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Extension.java index 18403669c4d..9a6c20018b8 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Extension.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Extension.java @@ -5,7 +5,9 @@ import org.bouncycastle.asn1.ASN1ObjectIdentifier; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public enum Extension { BASIC_CONSTRAINS(org.bouncycastle.asn1.x509.Extension.basicConstraints), SUBJECT_ALTERNATIVE_NAMES(org.bouncycastle.asn1.x509.Extension.subjectAlternativeName); diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyAlgorithm.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyAlgorithm.java index 4c4198adaac..d685f85b206 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyAlgorithm.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyAlgorithm.java @@ -3,7 +3,9 @@ package com.yahoo.vespa.athenz.tls; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public enum KeyAlgorithm { RSA("RSA"); diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreBuilder.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreBuilder.java index a9279f45129..3e63e441396 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreBuilder.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreBuilder.java @@ -19,7 +19,9 @@ import static java.util.Collections.singletonList; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public class KeyStoreBuilder { private final List<KeyEntry> keyEntries = new ArrayList<>(); diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreType.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreType.java index 6c08a60ff5b..b0bfe170789 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreType.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreType.java @@ -9,7 +9,9 @@ import java.security.KeyStoreException; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public enum KeyStoreType { JKS { KeyStore createKeystore() throws KeyStoreException { diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreUtils.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreUtils.java index 12aaa40cce4..96fe76a1f73 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreUtils.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyStoreUtils.java @@ -12,7 +12,9 @@ import java.security.KeyStore; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public class KeyStoreUtils { private KeyStoreUtils() {} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyUtils.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyUtils.java index c2be1a40893..fc4734d16ca 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyUtils.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/KeyUtils.java @@ -25,7 +25,9 @@ import java.security.spec.PKCS8EncodedKeySpec; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public class KeyUtils { private KeyUtils() {} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10Csr.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10Csr.java index e0029681b23..8138be9d7d8 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10Csr.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10Csr.java @@ -19,7 +19,9 @@ import static java.util.stream.Collectors.toList; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public class Pkcs10Csr { private final PKCS10CertificationRequest csr; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrBuilder.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrBuilder.java index 2135f569aeb..702b2f6cd4b 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrBuilder.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrBuilder.java @@ -24,7 +24,9 @@ import static com.yahoo.vespa.athenz.tls.SubjectAlternativeName.Type.DNS_NAME; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public class Pkcs10CsrBuilder { private final X500Principal subject; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrUtils.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrUtils.java index 2289c9ac0ee..be7bb3690bd 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrUtils.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/Pkcs10CsrUtils.java @@ -13,7 +13,9 @@ import java.io.UncheckedIOException; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public class Pkcs10CsrUtils { private Pkcs10CsrUtils() {} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SignatureAlgorithm.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SignatureAlgorithm.java index 2f3e2721751..1ff8ebbe78a 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SignatureAlgorithm.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SignatureAlgorithm.java @@ -3,7 +3,9 @@ package com.yahoo.vespa.athenz.tls; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public enum SignatureAlgorithm { SHA256_WITH_RSA("SHA256withRSA"), SHA512_WITH_RSA("SHA512withRSA"); diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SslContextBuilder.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SslContextBuilder.java index ba5785043da..63262eac048 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SslContextBuilder.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SslContextBuilder.java @@ -18,7 +18,9 @@ import java.security.cert.X509Certificate; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public class SslContextBuilder { private KeyStoreSupplier trustStoreSupplier; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SubjectAlternativeName.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SubjectAlternativeName.java index 8b89fc6fe7f..f5b0c7aa1c6 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SubjectAlternativeName.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/SubjectAlternativeName.java @@ -15,7 +15,9 @@ import static java.util.stream.Collectors.toList; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public class SubjectAlternativeName { private final Type type; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/X509CertificateBuilder.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/X509CertificateBuilder.java index c27b704f6a3..de593f25f61 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/X509CertificateBuilder.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/X509CertificateBuilder.java @@ -31,7 +31,9 @@ import static com.yahoo.vespa.athenz.tls.SubjectAlternativeName.Type.DNS_NAME; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public class X509CertificateBuilder { private final long serialNumber; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/X509CertificateUtils.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/X509CertificateUtils.java index d96ed17765c..8fc25ab06a4 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/X509CertificateUtils.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/X509CertificateUtils.java @@ -30,7 +30,9 @@ import static java.util.stream.Collectors.toList; /** * @author bjorncs + * @deprecated Use com.yahoo.security.* */ +@Deprecated public class X509CertificateUtils { private X509CertificateUtils() {} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/AthenzIdentities.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/AthenzIdentities.java index 82aecc62306..5e01d0cddfc 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/AthenzIdentities.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/AthenzIdentities.java @@ -1,11 +1,11 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.athenz.utils; +import com.yahoo.security.X509CertificateUtils; import com.yahoo.vespa.athenz.api.AthenzDomain; import com.yahoo.vespa.athenz.api.AthenzIdentity; import com.yahoo.vespa.athenz.api.AthenzService; import com.yahoo.vespa.athenz.api.AthenzUser; -import com.yahoo.vespa.athenz.tls.X509CertificateUtils; import java.security.cert.X509Certificate; import java.util.List; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/SiaUtils.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/SiaUtils.java index 05459e5488b..98d9061be02 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/SiaUtils.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/utils/SiaUtils.java @@ -2,8 +2,8 @@ package com.yahoo.vespa.athenz.utils; import com.yahoo.vespa.athenz.api.AthenzService; -import com.yahoo.vespa.athenz.tls.KeyUtils; -import com.yahoo.vespa.athenz.tls.X509CertificateUtils; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.X509CertificateUtils; import java.io.IOException; import java.io.UncheckedIOException; diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identity/SiaIdentityProviderTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identity/SiaIdentityProviderTest.java index 7b93ffb035d..6217d6fb2ee 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identity/SiaIdentityProviderTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identity/SiaIdentityProviderTest.java @@ -1,15 +1,14 @@ package com.yahoo.vespa.athenz.identity; -import com.google.common.io.Files; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.KeyStoreUtils; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.SignatureAlgorithm; +import com.yahoo.security.X509CertificateBuilder; +import com.yahoo.security.X509CertificateUtils; import com.yahoo.vespa.athenz.api.AthenzService; -import com.yahoo.vespa.athenz.tls.KeyAlgorithm; -import com.yahoo.vespa.athenz.tls.KeyStoreBuilder; -import com.yahoo.vespa.athenz.tls.KeyStoreType; -import com.yahoo.vespa.athenz.tls.KeyStoreUtils; -import com.yahoo.vespa.athenz.tls.KeyUtils; -import com.yahoo.vespa.athenz.tls.SignatureAlgorithm; -import com.yahoo.vespa.athenz.tls.X509CertificateBuilder; -import com.yahoo.vespa.athenz.tls.X509CertificateUtils; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -17,7 +16,8 @@ import org.junit.rules.TemporaryFolder; import javax.security.auth.x500.X500Principal; import java.io.File; import java.io.IOException; -import java.nio.charset.StandardCharsets; +import java.math.BigInteger; +import java.nio.file.Files; import java.security.KeyPair; import java.security.KeyStore; import java.security.cert.X509Certificate; @@ -62,12 +62,12 @@ public class SiaIdentityProviderTest { private void createPrivateKeyFile(File keyFile, KeyPair keypair) throws IOException { String privateKeyPem = KeyUtils.toPem(keypair.getPrivate()); - Files.write(privateKeyPem, keyFile, StandardCharsets.UTF_8); + Files.write(keyFile.toPath(), privateKeyPem.getBytes()); } private void createCertificateFile(X509Certificate certificate, File certificateFile) throws IOException { String certificatePem = X509CertificateUtils.toPem(certificate); - Files.write(certificatePem, certificateFile, StandardCharsets.UTF_8); + Files.write(certificateFile.toPath(), certificatePem.getBytes()); } private X509Certificate createCertificate(KeyPair keypair) { @@ -79,7 +79,7 @@ public class SiaIdentityProviderTest { now, now.plus(Duration.ofDays(1)), SignatureAlgorithm.SHA256_WITH_RSA, - 1) + BigInteger.ONE) .build(); } @@ -87,7 +87,7 @@ public class SiaIdentityProviderTest { KeyStore keystore = KeyStoreBuilder.withType(KeyStoreType.JKS) .withCertificateEntry("dummy-cert", certificate) .build(); - KeyStoreUtils.writeKeyStoreToFile(keystore, trustStoreFile); + KeyStoreUtils.writeKeyStoreToFile(keystore, trustStoreFile.toPath()); } }
\ No newline at end of file diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java index 38483bdbaee..4ad58a766e8 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java @@ -1,12 +1,12 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.athenz.identityprovider.client; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; import com.yahoo.vespa.athenz.api.AthenzService; import com.yahoo.vespa.athenz.identityprovider.api.IdentityType; import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId; -import com.yahoo.vespa.athenz.tls.KeyAlgorithm; -import com.yahoo.vespa.athenz.tls.KeyUtils; import org.junit.Test; import java.security.KeyPair; diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentityVerifierTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentityVerifierTest.java index 73382d267be..679476abe12 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentityVerifierTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentityVerifierTest.java @@ -1,24 +1,25 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.athenz.utils; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.X509CertificateBuilder; import com.yahoo.vespa.athenz.api.AthenzIdentity; import com.yahoo.vespa.athenz.api.AthenzService; import com.yahoo.vespa.athenz.tls.AthenzIdentityVerifier; -import com.yahoo.vespa.athenz.tls.X509CertificateBuilder; import org.junit.Test; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; import javax.security.auth.x500.X500Principal; +import java.math.BigInteger; import java.security.KeyPair; -import java.security.KeyPairGenerator; -import java.security.NoSuchAlgorithmException; import java.security.cert.Certificate; import java.security.cert.X509Certificate; import java.time.Duration; import java.time.Instant; -import static com.yahoo.vespa.athenz.tls.SignatureAlgorithm.SHA256_WITH_RSA; +import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_ECDSA; import static java.util.Collections.singleton; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -34,23 +35,17 @@ public class AthenzIdentityVerifierTest { public void verifies_certificate_with_athenz_service_as_common_name() throws Exception { AthenzIdentity trustedIdentity = new AthenzService("mydomain", "alice"); AthenzIdentity unknownIdentity = new AthenzService("mydomain", "mallory"); - KeyPair keyPair = createKeyPair(); + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC); AthenzIdentityVerifier verifier = new AthenzIdentityVerifier(singleton(trustedIdentity)); assertTrue(verifier.verify("hostname", createSslSessionMock(createSelfSignedCertificate(keyPair, trustedIdentity)))); assertFalse(verifier.verify("hostname", createSslSessionMock(createSelfSignedCertificate(keyPair, unknownIdentity)))); } - private static KeyPair createKeyPair() throws NoSuchAlgorithmException { - KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA"); - keyGen.initialize(512); - return keyGen.generateKeyPair(); - } - private static X509Certificate createSelfSignedCertificate(KeyPair keyPair, AthenzIdentity identity) { X500Principal x500Name = new X500Principal("CN="+ identity.getFullName()); Instant now = Instant.now(); return X509CertificateBuilder - .fromKeypair(keyPair, x500Name, now, now.plus(Duration.ofDays(30)), SHA256_WITH_RSA, 1) + .fromKeypair(keyPair, x500Name, now, now.plus(Duration.ofDays(30)), SHA256_WITH_ECDSA, BigInteger.ONE) .setBasicConstraints(true, true) .build(); } diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/ntoken/NTokenValidatorTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/ntoken/NTokenValidatorTest.java index 22f97ca8b60..750968a437e 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/ntoken/NTokenValidatorTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/ntoken/NTokenValidatorTest.java @@ -6,8 +6,8 @@ import com.yahoo.vespa.athenz.api.AthenzIdentity; import com.yahoo.vespa.athenz.api.AthenzPrincipal; import com.yahoo.vespa.athenz.api.AthenzUser; import com.yahoo.vespa.athenz.api.NToken; -import com.yahoo.vespa.athenz.tls.KeyAlgorithm; -import com.yahoo.vespa.athenz.tls.KeyUtils; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; import com.yahoo.vespa.athenz.utils.ntoken.NTokenValidator.InvalidTokenException; import org.junit.Rule; import org.junit.Test; diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/StatusResponse.java b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/StatusResponse.java index c6b5a6cb4fe..38f9e229726 100755 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/StatusResponse.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/StatusResponse.java @@ -11,6 +11,10 @@ import java.io.IOException; import java.io.OutputStream; import java.io.OutputStreamWriter; +/** + * @deprecated Legacy API. Will be removed in Vespa 7 + */ +@Deprecated public class StatusResponse extends HttpResponse { MetricManager manager; diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerCompatibility.java b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerCompatibility.java index b1a7b6dbdeb..dc23589f8ed 100755 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerCompatibility.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerCompatibility.java @@ -9,6 +9,10 @@ import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.container.jdisc.ThreadedHttpRequestHandler; +/** + * @deprecated Legacy API. Will be removed in Vespa 7 + */ +@Deprecated public class VespaFeedHandlerCompatibility extends ThreadedHttpRequestHandler { private final VespaFeedHandlerGet getHandler; diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerGet.java b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerGet.java index 70631e0e66c..ed4750148bd 100755 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerGet.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerGet.java @@ -11,6 +11,10 @@ import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.container.jdisc.ThreadedHttpRequestHandler; import com.yahoo.search.handler.SearchHandler; +/** + * @deprecated Legacy API. Will be removed in Vespa 7 + */ +@Deprecated public class VespaFeedHandlerGet extends ThreadedHttpRequestHandler { private final SearchHandler searchHandler; diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemove.java b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemove.java index 36ab8090e95..4673efb4605 100755 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemove.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemove.java @@ -21,6 +21,10 @@ import java.io.BufferedReader; import java.io.InputStreamReader; import java.util.concurrent.Executor; +/** + * @deprecated Legacy API. Will be removed in Vespa 7 + */ +@Deprecated public class VespaFeedHandlerRemove extends VespaFeedHandlerBase { @Inject diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemoveLocation.java b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemoveLocation.java index 04ca6798b4c..ecb911953f6 100644 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemoveLocation.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerRemoveLocation.java @@ -20,6 +20,10 @@ import com.yahoo.vespaclient.config.FeederConfig; import java.util.concurrent.Executor; +/** + * @deprecated Legacy API. Will be removed in Vespa 7 + */ +@Deprecated public class VespaFeedHandlerRemoveLocation extends VespaFeedHandlerBase { @Inject diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerStatus.java b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerStatus.java index 94ad18fbb51..8c07ea30312 100755 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerStatus.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerStatus.java @@ -15,6 +15,10 @@ import com.yahoo.metrics.MetricManager; import com.yahoo.metrics.MetricSet; import com.yahoo.vespaclient.config.FeederConfig; +/** + * @deprecated Legacy API. Will be removed in Vespa 7 + */ +@Deprecated public class VespaFeedHandlerStatus extends ThreadedHttpRequestHandler { private MetricManager manager; diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerVisit.java b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerVisit.java index c9af0933799..5b5224775cb 100644 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerVisit.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/feedhandler/VespaFeedHandlerVisit.java @@ -13,7 +13,10 @@ import com.yahoo.search.handler.SearchHandler; /** * @author thomasg + * + * @deprecated Legacy API. Will be removed in Vespa 7 */ +@Deprecated public class VespaFeedHandlerVisit extends ThreadedHttpRequestHandler { private final SearchHandler searchHandler; diff --git a/vespaclient-container-plugin/src/test/java/com/yahoo/feedhandler/VespaFeedHandlerTestCase.java b/vespaclient-container-plugin/src/test/java/com/yahoo/feedhandler/VespaFeedHandlerTestCase.java index d1ed02209b2..fcc4e18d66e 100755 --- a/vespaclient-container-plugin/src/test/java/com/yahoo/feedhandler/VespaFeedHandlerTestCase.java +++ b/vespaclient-container-plugin/src/test/java/com/yahoo/feedhandler/VespaFeedHandlerTestCase.java @@ -39,6 +39,7 @@ import java.util.zip.GZIPOutputStream; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +@SuppressWarnings("deprecation") // VespaFeedHandler classes are going away on Vespa 7 public class VespaFeedHandlerTestCase { private VespaFeedHandler feedHandler; diff --git a/vespaclient/src/perl/lib/Yahoo/Vespa/VespaModel.pm b/vespaclient/src/perl/lib/Yahoo/Vespa/VespaModel.pm index fd324540bba..18af0bbdecd 100644 --- a/vespaclient/src/perl/lib/Yahoo/Vespa/VespaModel.pm +++ b/vespaclient/src/perl/lib/Yahoo/Vespa/VespaModel.pm @@ -162,7 +162,7 @@ sub setModelRetrievalFunction { # (Function) } sub retrieveModelConfigDefault { # () my $VESPA_HOME= $ENV{'VESPA_HOME'}; - my $cmd = ${VESPA_HOME} . '/bin/vespa-get-config -n cloud.config.model -i admin/model'; + my $cmd = ${VESPA_HOME} . '/bin/vespa-get-config -l -n cloud.config.model -i admin/model'; if (defined $CONFIG_REQUEST_TIMEOUT) { $cmd .= " -w $CONFIG_REQUEST_TIMEOUT"; diff --git a/vespajlib/pom.xml b/vespajlib/pom.xml index 880d039bc54..5b9c143a447 100644 --- a/vespajlib/pom.xml +++ b/vespajlib/pom.xml @@ -17,29 +17,27 @@ </description> <dependencies> - <dependency> - <groupId>com.google.guava</groupId> - <artifactId>guava</artifactId> - <scope>provided</scope> - </dependency> + + <!-- compile scope --> <dependency> <groupId>net.jpountz.lz4</groupId> <artifactId>lz4</artifactId> </dependency> <dependency> - <groupId>org.hamcrest</groupId> - <artifactId>hamcrest-library</artifactId> - <scope>test</scope> + <groupId>commons-lang</groupId> + <artifactId>commons-lang</artifactId> </dependency> <dependency> - <groupId>org.mockito</groupId> - <artifactId>mockito-core</artifactId> - <scope>test</scope> + <groupId>org.apache.commons</groupId> + <artifactId>commons-exec</artifactId> </dependency> + + + <!-- provided scope --> <dependency> - <groupId>junit</groupId> - <artifactId>junit</artifactId> - <scope>test</scope> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + <scope>provided</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> @@ -54,22 +52,41 @@ <scope>provided</scope> </dependency> <dependency> - <groupId>commons-lang</groupId> - <artifactId>commons-lang</artifactId> + <groupId>org.bouncycastle</groupId> + <artifactId>bcprov-jdk15on</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.bouncycastle</groupId> + <artifactId>bcpkix-jdk15on</artifactId> + <scope>provided</scope> </dependency> <dependency> <groupId>com.fasterxml.jackson.core</groupId> <artifactId>jackson-core</artifactId> - <scope>test</scope> + <scope>provided</scope> </dependency> <dependency> <groupId>com.fasterxml.jackson.core</groupId> <artifactId>jackson-databind</artifactId> + <scope>provided</scope> + </dependency> + + <!-- test scope --> + <dependency> + <groupId>org.hamcrest</groupId> + <artifactId>hamcrest-library</artifactId> <scope>test</scope> </dependency> <dependency> - <groupId>org.apache.commons</groupId> - <artifactId>commons-exec</artifactId> + <groupId>org.mockito</groupId> + <artifactId>mockito-core</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>junit</groupId> + <artifactId>junit</artifactId> + <scope>test</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> @@ -77,6 +94,7 @@ <version>${project.version}</version> <scope>test</scope> </dependency> + </dependencies> <build> <plugins> diff --git a/vespajlib/src/main/java/com/yahoo/security/BasicConstraintsExtension.java b/vespajlib/src/main/java/com/yahoo/security/BasicConstraintsExtension.java new file mode 100644 index 00000000000..d3c08ba27d0 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/BasicConstraintsExtension.java @@ -0,0 +1,14 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +/** + * @author bjorncs + */ +class BasicConstraintsExtension { + final boolean isCritical, isCertAuthorityCertificate; + + BasicConstraintsExtension(boolean isCritical, boolean isCertAuthorityCertificate) { + this.isCritical = isCritical; + this.isCertAuthorityCertificate = isCertAuthorityCertificate; + } +} diff --git a/vespajlib/src/main/java/com/yahoo/security/BouncyCastleProviderHolder.java b/vespajlib/src/main/java/com/yahoo/security/BouncyCastleProviderHolder.java new file mode 100644 index 00000000000..48a23a1fe7e --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/BouncyCastleProviderHolder.java @@ -0,0 +1,14 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.bouncycastle.jce.provider.BouncyCastleProvider; + +/** + * @author bjorncs + */ +class BouncyCastleProviderHolder { + + private static final BouncyCastleProvider bcProvider = new BouncyCastleProvider(); + + static BouncyCastleProvider getInstance() { return bcProvider; } +} diff --git a/vespajlib/src/main/java/com/yahoo/security/Extension.java b/vespajlib/src/main/java/com/yahoo/security/Extension.java new file mode 100644 index 00000000000..46b781c9c86 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/Extension.java @@ -0,0 +1,22 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.bouncycastle.asn1.ASN1ObjectIdentifier; + +/** + * @author bjorncs + */ +public enum Extension { + BASIC_CONSTRAINTS(org.bouncycastle.asn1.x509.Extension.basicConstraints), + SUBJECT_ALTERNATIVE_NAMES(org.bouncycastle.asn1.x509.Extension.subjectAlternativeName); + + final ASN1ObjectIdentifier extensionOId; + + Extension(ASN1ObjectIdentifier extensionOId) { + this.extensionOId = extensionOId; + } + + public String getOId() { + return extensionOId.getId(); + } +} diff --git a/vespajlib/src/main/java/com/yahoo/security/KeyAlgorithm.java b/vespajlib/src/main/java/com/yahoo/security/KeyAlgorithm.java new file mode 100644 index 00000000000..3218f81f0d6 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/KeyAlgorithm.java @@ -0,0 +1,20 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +/** + * @author bjorncs + */ +public enum KeyAlgorithm { + RSA("RSA"), + EC("EC"); + + final String algorithmName; + + KeyAlgorithm(String algorithmName) { + this.algorithmName = algorithmName; + } + + String getAlgorithmName() { + return algorithmName; + } +} diff --git a/vespajlib/src/main/java/com/yahoo/security/KeyStoreBuilder.java b/vespajlib/src/main/java/com/yahoo/security/KeyStoreBuilder.java new file mode 100644 index 00000000000..2160fbf6455 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/KeyStoreBuilder.java @@ -0,0 +1,121 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.PrivateKey; +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.List; + +import static java.util.Collections.singletonList; + +/** + * @author bjorncs + */ +public class KeyStoreBuilder { + + private final List<KeyEntry> keyEntries = new ArrayList<>(); + private final List<CertificateEntry> certificateEntries = new ArrayList<>(); + + private final KeyStoreType keyStoreType; + private Path inputFile; + private char[] inputFilePassword; + + private KeyStoreBuilder(KeyStoreType keyStoreType) { + this.keyStoreType = keyStoreType; + } + + public static KeyStoreBuilder withType(KeyStoreType type) { + return new KeyStoreBuilder(type); + } + + public KeyStoreBuilder fromFile(Path file, char[] password) { + this.inputFile = file; + this.inputFilePassword = password; + return this; + } + + public KeyStoreBuilder fromFile(Path file) { + return fromFile(file, null); + } + + public KeyStoreBuilder withKeyEntry(String alias, PrivateKey privateKey, char[] password, List<X509Certificate> certificateChain) { + keyEntries.add(new KeyEntry(alias, privateKey, certificateChain, password)); + return this; + } + + public KeyStoreBuilder withKeyEntry(String alias, PrivateKey privateKey, char[] password, X509Certificate certificate) { + return withKeyEntry(alias, privateKey, password, singletonList(certificate)); + } + + public KeyStoreBuilder withKeyEntry(String alias, PrivateKey privateKey, X509Certificate certificate) { + return withKeyEntry(alias, privateKey, null, certificate); + } + + public KeyStoreBuilder withKeyEntry(String alias, PrivateKey privateKey, List<X509Certificate> certificateChain) { + return withKeyEntry(alias, privateKey, null, certificateChain); + } + + public KeyStoreBuilder withCertificateEntry(String alias, X509Certificate certificate) { + certificateEntries.add(new CertificateEntry(alias, certificate)); + return this; + } + + public KeyStore build() { + try { + KeyStore keystore = this.keyStoreType.createKeystore(); + if (this.inputFile != null) { + try (InputStream in = new BufferedInputStream(Files.newInputStream(this.inputFile))) { + keystore.load(in, this.inputFilePassword); + } + } else { + keystore.load(null); + } + for (KeyEntry entry : keyEntries) { + char[] password = entry.password != null ? entry.password : new char[0]; + Certificate[] certificateChain = entry.certificateChain.toArray(new Certificate[entry.certificateChain.size()]); + keystore.setKeyEntry(entry.alias, entry.privateKey, password, certificateChain); + } + for (CertificateEntry entry : certificateEntries) { + keystore.setCertificateEntry(entry.alias, entry.certificate); + } + return keystore; + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private static class KeyEntry { + final String alias; + final PrivateKey privateKey; + final List<X509Certificate> certificateChain; + final char[] password; + + KeyEntry(String alias, PrivateKey privateKey, List<X509Certificate> certificateChain, char[] password) { + this.alias = alias; + this.privateKey = privateKey; + this.certificateChain = certificateChain; + this.password = password; + } + } + + private static class CertificateEntry { + final String alias; + final X509Certificate certificate; + + CertificateEntry(String alias, X509Certificate certificate) { + this.alias = alias; + this.certificate = certificate; + } + } +} diff --git a/vespajlib/src/main/java/com/yahoo/security/KeyStoreType.java b/vespajlib/src/main/java/com/yahoo/security/KeyStoreType.java new file mode 100644 index 00000000000..7fb8df35286 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/KeyStoreType.java @@ -0,0 +1,23 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.KeyStoreException; + +/** + * @author bjorncs + */ +public enum KeyStoreType { + JKS { + KeyStore createKeystore() throws KeyStoreException { + return KeyStore.getInstance("JKS"); + } + }, + PKCS12 { + KeyStore createKeystore() throws KeyStoreException { + return KeyStore.getInstance("PKCS12", BouncyCastleProviderHolder.getInstance()); + } + }; + abstract KeyStore createKeystore() throws GeneralSecurityException; +} diff --git a/vespajlib/src/main/java/com/yahoo/security/KeyStoreUtils.java b/vespajlib/src/main/java/com/yahoo/security/KeyStoreUtils.java new file mode 100644 index 00000000000..f0c4d99bf69 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/KeyStoreUtils.java @@ -0,0 +1,34 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import java.io.BufferedOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.GeneralSecurityException; +import java.security.KeyStore; + +/** + * @author bjorncs + */ +public class KeyStoreUtils { + private KeyStoreUtils() {} + + public static void writeKeyStoreToFile(KeyStore keyStore, Path file, char[] password) { + try (OutputStream out = new BufferedOutputStream(Files.newOutputStream(file))) { + keyStore.store(out, password); + } catch (IOException e) { + throw new UncheckedIOException(e); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + + } + + public static void writeKeyStoreToFile(KeyStore keyStore, Path file) { + writeKeyStoreToFile(keyStore, file, new char[0]); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/security/KeyUtils.java b/vespajlib/src/main/java/com/yahoo/security/KeyUtils.java new file mode 100644 index 00000000000..1c3157d639f --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/KeyUtils.java @@ -0,0 +1,121 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.bouncycastle.asn1.ASN1Encodable; +import org.bouncycastle.asn1.ASN1Primitive; +import org.bouncycastle.asn1.pkcs.PrivateKeyInfo; +import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPrivateKey; +import org.bouncycastle.jce.spec.ECParameterSpec; +import org.bouncycastle.jce.spec.ECPublicKeySpec; +import org.bouncycastle.math.ec.ECPoint; +import org.bouncycastle.math.ec.FixedPointCombMultiplier; +import org.bouncycastle.openssl.PEMKeyPair; +import org.bouncycastle.openssl.PEMParser; +import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter; +import org.bouncycastle.openssl.jcajce.JcaPEMWriter; +import org.bouncycastle.util.io.pem.PemObject; + +import java.io.IOException; +import java.io.StringReader; +import java.io.StringWriter; +import java.io.UncheckedIOException; +import java.security.GeneralSecurityException; +import java.security.KeyFactory; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.interfaces.RSAPrivateCrtKey; +import java.security.spec.PKCS8EncodedKeySpec; +import java.security.spec.RSAPublicKeySpec; + +import static com.yahoo.security.KeyAlgorithm.EC; +import static com.yahoo.security.KeyAlgorithm.RSA; + +/** + * @author bjorncs + */ +// TODO Support serialization of EC private keys +public class KeyUtils { + private KeyUtils() {} + + public static KeyPair generateKeypair(KeyAlgorithm algorithm, int keySize) { + try { + KeyPairGenerator keyGen = KeyPairGenerator.getInstance(algorithm.getAlgorithmName(), BouncyCastleProviderHolder.getInstance()); + if (keySize != -1) { + keyGen.initialize(keySize); + } + return keyGen.genKeyPair(); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } + + public static KeyPair generateKeypair(KeyAlgorithm algorithm) { + return generateKeypair(algorithm, -1); + } + + public static PublicKey extractPublicKey(PrivateKey privateKey) { + String algorithm = privateKey.getAlgorithm(); + try { + if (algorithm.equals(RSA.getAlgorithmName())) { + KeyFactory keyFactory = KeyFactory.getInstance(RSA.getAlgorithmName(), BouncyCastleProviderHolder.getInstance()); + RSAPrivateCrtKey rsaPrivateCrtKey = (RSAPrivateCrtKey) privateKey; + RSAPublicKeySpec keySpec = new RSAPublicKeySpec(rsaPrivateCrtKey.getModulus(), rsaPrivateCrtKey.getPublicExponent()); + return keyFactory.generatePublic(keySpec); + } else if (algorithm.equals(EC.getAlgorithmName())) { + KeyFactory keyFactory = KeyFactory.getInstance(EC.getAlgorithmName(), BouncyCastleProviderHolder.getInstance()); + BCECPrivateKey ecPrivateKey = (BCECPrivateKey) privateKey; + ECParameterSpec ecParameterSpec = ecPrivateKey.getParameters(); + ECPoint ecPoint = new FixedPointCombMultiplier().multiply(ecParameterSpec.getG(), ecPrivateKey.getD()); + ECPublicKeySpec keySpec = new ECPublicKeySpec(ecPoint, ecParameterSpec); + return keyFactory.generatePublic(keySpec); + } else { + throw new IllegalArgumentException("Unexpected key algorithm: " + algorithm); + } + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } + + public static PrivateKey fromPemEncodedPrivateKey(String pem) { + try (PEMParser parser = new PEMParser(new StringReader(pem))) { + Object pemObject = parser.readObject(); + if (pemObject instanceof PrivateKeyInfo) { + PrivateKeyInfo keyInfo = (PrivateKeyInfo) pemObject; + PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(keyInfo.getEncoded()); + return KeyFactory.getInstance(RSA.getAlgorithmName()).generatePrivate(keySpec); + } else if (pemObject instanceof PEMKeyPair) { + PEMKeyPair pemKeypair = (PEMKeyPair) pemObject; + PrivateKeyInfo keyInfo = pemKeypair.getPrivateKeyInfo(); + JcaPEMKeyConverter pemConverter = new JcaPEMKeyConverter(); + return pemConverter.getPrivateKey(keyInfo); + } + throw new IllegalArgumentException("Unexpected type of PEM type: " + pemObject); + } catch (IOException e) { + throw new UncheckedIOException(e); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } + + public static String toPem(PrivateKey privateKey) { + try (StringWriter stringWriter = new StringWriter(); JcaPEMWriter pemWriter = new JcaPEMWriter(stringWriter)) { + // Note: Encoding using PKCS#1 as this is to be read by tools only supporting PKCS#1 + pemWriter.writeObject(new PemObject("RSA PRIVATE KEY", getPkcs1Bytes(privateKey))); + pemWriter.flush(); + return stringWriter.toString(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private static byte[] getPkcs1Bytes(PrivateKey privateKey) throws IOException{ + + byte[] privBytes = privateKey.getEncoded(); + PrivateKeyInfo pkInfo = PrivateKeyInfo.getInstance(privBytes); + ASN1Encodable encodable = pkInfo.parsePrivateKey(); + ASN1Primitive primitive = encodable.toASN1Primitive(); + return primitive.getEncoded(); + } +} diff --git a/vespajlib/src/main/java/com/yahoo/security/Pkcs10Csr.java b/vespajlib/src/main/java/com/yahoo/security/Pkcs10Csr.java new file mode 100644 index 00000000000..e08ee117fcd --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/Pkcs10Csr.java @@ -0,0 +1,71 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.bouncycastle.asn1.ASN1ObjectIdentifier; +import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers; +import org.bouncycastle.asn1.x509.BasicConstraints; +import org.bouncycastle.asn1.x509.Extension; +import org.bouncycastle.asn1.x509.Extensions; +import org.bouncycastle.asn1.x509.GeneralNames; +import org.bouncycastle.pkcs.PKCS10CertificationRequest; + +import javax.security.auth.x500.X500Principal; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import static java.util.Collections.emptyList; +import static java.util.stream.Collectors.toList; + +/** + * @author bjorncs + */ +public class Pkcs10Csr { + + private final PKCS10CertificationRequest csr; + + Pkcs10Csr(PKCS10CertificationRequest csr) { + this.csr = csr; + } + + PKCS10CertificationRequest getBcCsr() { + return csr; + } + + public X500Principal getSubject() { + return new X500Principal(csr.getSubject().toString()); + } + + public List<SubjectAlternativeName> getSubjectAlternativeNames() { + return getExtensions() + .map(extensions -> GeneralNames.fromExtensions(extensions, Extension.subjectAlternativeName)) + .map(SubjectAlternativeName::fromGeneralNames) + .orElse(emptyList()); + } + + /** + * @return If basic constraints extension is present: returns true if CA cert, false otherwise. Returns empty if the extension is not present. + */ + public Optional<Boolean> getBasicConstraints() { + return getExtensions() + .map(BasicConstraints::fromExtensions) + .map(BasicConstraints::isCA); + } + + public List<String> getExtensionOIds() { + return getExtensions() + .map(extensions -> Arrays.stream(extensions.getExtensionOIDs()) + .map(ASN1ObjectIdentifier::getId) + .collect(toList())) + .orElse(emptyList()); + + } + + private Optional<Extensions> getExtensions() { + return Optional.of(csr.getAttributes(PKCSObjectIdentifiers.pkcs_9_at_extensionRequest)) + .filter(attributes -> attributes.length > 0) + .map(attributes -> attributes[0]) + .map(attribute -> Extensions.getInstance(attribute.getAttrValues().getObjectAt(0))); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/security/Pkcs10CsrBuilder.java b/vespajlib/src/main/java/com/yahoo/security/Pkcs10CsrBuilder.java new file mode 100644 index 00000000000..b46293b2e2f --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/Pkcs10CsrBuilder.java @@ -0,0 +1,105 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers; +import org.bouncycastle.asn1.x509.BasicConstraints; +import org.bouncycastle.asn1.x509.Extension; +import org.bouncycastle.asn1.x509.ExtensionsGenerator; +import org.bouncycastle.asn1.x509.GeneralName; +import org.bouncycastle.asn1.x509.GeneralNames; +import org.bouncycastle.operator.ContentSigner; +import org.bouncycastle.operator.OperatorCreationException; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; +import org.bouncycastle.pkcs.PKCS10CertificationRequestBuilder; +import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequestBuilder; + +import javax.security.auth.x500.X500Principal; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.security.KeyPair; +import java.util.ArrayList; +import java.util.List; + +import static com.yahoo.security.SubjectAlternativeName.Type.DNS_NAME; + +/** + * @author bjorncs + */ +public class Pkcs10CsrBuilder { + + private final X500Principal subject; + private final KeyPair keyPair; + private final List<SubjectAlternativeName> subjectAlternativeNames = new ArrayList<>(); + private final SignatureAlgorithm signatureAlgorithm; + private BasicConstraintsExtension basicConstraintsExtension; + + private Pkcs10CsrBuilder(X500Principal subject, + KeyPair keyPair, + SignatureAlgorithm signatureAlgorithm) { + this.subject = subject; + this.keyPair = keyPair; + this.signatureAlgorithm = signatureAlgorithm; + } + + public static Pkcs10CsrBuilder fromKeypair(X500Principal subject, + KeyPair keyPair, + SignatureAlgorithm signatureAlgorithm) { + return new Pkcs10CsrBuilder(subject, keyPair, signatureAlgorithm); + } + + public Pkcs10CsrBuilder addSubjectAlternativeName(String dns) { + this.subjectAlternativeNames.add(new SubjectAlternativeName(DNS_NAME, dns)); + return this; + } + + public Pkcs10CsrBuilder addSubjectAlternativeName(SubjectAlternativeName san) { + this.subjectAlternativeNames.add(san); + return this; + } + + public Pkcs10CsrBuilder addSubjectAlternativeName(SubjectAlternativeName.Type type, String value) { + this.subjectAlternativeNames.add(new SubjectAlternativeName(type, value)); + return this; + } + + public Pkcs10CsrBuilder setBasicConstraints(boolean isCritical, boolean isCertAuthorityCertificate) { + this.basicConstraintsExtension = new BasicConstraintsExtension(isCritical, isCertAuthorityCertificate); + return this; + } + + public Pkcs10CsrBuilder setIsCertAuthority(boolean isCertAuthority) { + return setBasicConstraints(true, isCertAuthority); + } + + public Pkcs10Csr build() { + try { + PKCS10CertificationRequestBuilder requestBuilder = + new JcaPKCS10CertificationRequestBuilder(subject, keyPair.getPublic()); + ExtensionsGenerator extGen = new ExtensionsGenerator(); + if (basicConstraintsExtension != null) { + extGen.addExtension( + Extension.basicConstraints, + basicConstraintsExtension.isCritical, + new BasicConstraints(basicConstraintsExtension.isCertAuthorityCertificate)); + } + if (!subjectAlternativeNames.isEmpty()) { + GeneralNames generalNames = new GeneralNames( + subjectAlternativeNames.stream() + .map(SubjectAlternativeName::toGeneralName) + .toArray(GeneralName[]::new)); + extGen.addExtension(Extension.subjectAlternativeName, false, generalNames); + } + requestBuilder.addAttribute(PKCSObjectIdentifiers.pkcs_9_at_extensionRequest, extGen.generate()); + ContentSigner contentSigner = new JcaContentSignerBuilder(signatureAlgorithm.getAlgorithmName()) + .setProvider(BouncyCastleProviderHolder.getInstance()) + .build(keyPair.getPrivate()); + return new Pkcs10Csr(requestBuilder.build(contentSigner)); + } catch (OperatorCreationException e) { + throw new RuntimeException(e); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/security/Pkcs10CsrUtils.java b/vespajlib/src/main/java/com/yahoo/security/Pkcs10CsrUtils.java new file mode 100644 index 00000000000..6f12450528d --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/Pkcs10CsrUtils.java @@ -0,0 +1,38 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.bouncycastle.openssl.PEMParser; +import org.bouncycastle.openssl.jcajce.JcaPEMWriter; +import org.bouncycastle.pkcs.PKCS10CertificationRequest; +import org.bouncycastle.util.io.pem.PemObject; + +import java.io.IOException; +import java.io.StringReader; +import java.io.StringWriter; +import java.io.UncheckedIOException; + +/** + * @author bjorncs + */ +public class Pkcs10CsrUtils { + + private Pkcs10CsrUtils() {} + + public static Pkcs10Csr fromPem(String pem) { + try (PEMParser pemParser = new PEMParser(new StringReader(pem))) { + return new Pkcs10Csr((PKCS10CertificationRequest) pemParser.readObject()); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public static String toPem(Pkcs10Csr csr) { + try (StringWriter stringWriter = new StringWriter(); JcaPEMWriter pemWriter = new JcaPEMWriter(stringWriter)) { + pemWriter.writeObject(new PemObject("CERTIFICATE REQUEST", csr.getBcCsr().getEncoded())); + pemWriter.flush(); + return stringWriter.toString(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } +} diff --git a/vespajlib/src/main/java/com/yahoo/security/SignatureAlgorithm.java b/vespajlib/src/main/java/com/yahoo/security/SignatureAlgorithm.java new file mode 100644 index 00000000000..fbff18f5c12 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/SignatureAlgorithm.java @@ -0,0 +1,22 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +/** + * @author bjorncs + */ +public enum SignatureAlgorithm { + SHA256_WITH_RSA("SHA256withRSA"), + SHA512_WITH_RSA("SHA512withRSA"), + SHA256_WITH_ECDSA("SHA256withECDSA"), + SHA512_WITH_ECDSA("SHA512withECDSA"); + + private final String algorithmName; + + SignatureAlgorithm(String algorithmName) { + this.algorithmName = algorithmName; + } + + public String getAlgorithmName() { + return algorithmName; + } +} diff --git a/vespajlib/src/main/java/com/yahoo/security/SslContextBuilder.java b/vespajlib/src/main/java/com/yahoo/security/SslContextBuilder.java new file mode 100644 index 00000000000..75ab2417edf --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/SslContextBuilder.java @@ -0,0 +1,137 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.Collections; +import java.util.List; + +import static java.util.Collections.singletonList; + +/** + * @author bjorncs + */ +public class SslContextBuilder { + + private KeyStoreSupplier trustStoreSupplier; + private KeyStoreSupplier keyStoreSupplier; + private char[] keyStorePassword; + + public SslContextBuilder() {} + + public SslContextBuilder withTrustStore(Path file, KeyStoreType trustStoreType) { + this.trustStoreSupplier = () -> KeyStoreBuilder.withType(trustStoreType).fromFile(file).build(); + return this; + } + + public SslContextBuilder withTrustStore(KeyStore trustStore) { + this.trustStoreSupplier = () -> trustStore; + return this; + } + + public SslContextBuilder withTrustStore(X509Certificate caCertificate) { + return withTrustStore(singletonList(caCertificate)); + } + + public SslContextBuilder withTrustStore(List<X509Certificate> caCertificates) { + this.trustStoreSupplier = () -> createTrustStore(caCertificates); + return this; + } + + public SslContextBuilder withTrustStore(Path pemEncodedCaCertificates) { + this.trustStoreSupplier = () -> { + List<X509Certificate> caCertificates = + X509CertificateUtils.certificateListFromPem(new String(Files.readAllBytes(pemEncodedCaCertificates))); + return createTrustStore(caCertificates); + }; + return this; + } + + public SslContextBuilder withKeyStore(PrivateKey privateKey, X509Certificate certificate) { + char[] pwd = new char[0]; + this.keyStoreSupplier = () -> KeyStoreBuilder.withType(KeyStoreType.JKS).withKeyEntry("default", privateKey, certificate).build(); + this.keyStorePassword = pwd; + return this; + } + + public SslContextBuilder withKeyStore(KeyStore keyStore, char[] password) { + this.keyStoreSupplier = () -> keyStore; + this.keyStorePassword = password; + return this; + } + + public SslContextBuilder withKeyStore(Path file, char[] password, KeyStoreType keyStoreType) { + this.keyStoreSupplier = () -> KeyStoreBuilder.withType(keyStoreType).fromFile(file, password).build(); + this.keyStorePassword = password; + return this; + } + + public SslContextBuilder withKeyStore(Path privateKeyPemFile, Path certificatesPemFile) { + this.keyStoreSupplier = + () -> { + PrivateKey privateKey = KeyUtils.fromPemEncodedPrivateKey(new String(Files.readAllBytes(privateKeyPemFile))); + List<X509Certificate> certificates = X509CertificateUtils.certificateListFromPem(new String(Files.readAllBytes(certificatesPemFile))); + return KeyStoreBuilder.withType(KeyStoreType.JKS) + .withKeyEntry("default", privateKey, certificates) + .build(); + }; + this.keyStorePassword = new char[0]; + return this; + } + + public SSLContext build() { + try { + SSLContext sslContext = SSLContext.getInstance("TLSv1.2"); + TrustManager[] trustManagers = + trustStoreSupplier != null ? createTrustManagers(trustStoreSupplier) : null; + KeyManager[] keyManagers = + keyStoreSupplier != null ? createKeyManagers(keyStoreSupplier, keyStorePassword) : null; + sslContext.init(keyManagers, trustManagers, null); + return sslContext; + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private static TrustManager[] createTrustManagers(KeyStoreSupplier trustStoreSupplier) + throws GeneralSecurityException, IOException { + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(trustStoreSupplier.get()); + return trustManagerFactory.getTrustManagers(); + } + + private static KeyManager[] createKeyManagers(KeyStoreSupplier keyStoreSupplier, char[] password) + throws GeneralSecurityException, IOException { + KeyManagerFactory keyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(keyStoreSupplier.get(), password); + return keyManagerFactory.getKeyManagers(); + } + + private static KeyStore createTrustStore(List<X509Certificate> caCertificates) { + KeyStoreBuilder trustStoreBuilder = KeyStoreBuilder.withType(KeyStoreType.JKS); + for (int i = 0; i < caCertificates.size(); i++) { + trustStoreBuilder.withCertificateEntry("cert-" + i, caCertificates.get(i)); + } + return trustStoreBuilder.build(); + } + + private interface KeyStoreSupplier { + KeyStore get() throws IOException, GeneralSecurityException; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/security/SubjectAlternativeName.java b/vespajlib/src/main/java/com/yahoo/security/SubjectAlternativeName.java new file mode 100644 index 00000000000..29395c75e70 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/SubjectAlternativeName.java @@ -0,0 +1,114 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.bouncycastle.asn1.ASN1Encodable; +import org.bouncycastle.asn1.DERIA5String; +import org.bouncycastle.asn1.x500.X500Name; +import org.bouncycastle.asn1.x509.GeneralName; +import org.bouncycastle.asn1.x509.GeneralNames; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import static java.util.stream.Collectors.toList; + +/** + * @author bjorncs + */ +public class SubjectAlternativeName { + + private final Type type; + private final String value; + + public SubjectAlternativeName(Type type, String value) { + this.type = type; + this.value = value; + } + + SubjectAlternativeName(GeneralName bcGeneralName) { + this.type = Type.fromTag(bcGeneralName.getTagNo()); + this.value = getValue(bcGeneralName); + } + + public Type getType() { + return type; + } + + public String getValue() { + return value; + } + + GeneralName toGeneralName() { + return new GeneralName(type.tag, value); + } + + static List<SubjectAlternativeName> fromGeneralNames(GeneralNames generalNames) { + return Arrays.stream(generalNames.getNames()).map(SubjectAlternativeName::new).collect(toList()); + } + + private String getValue(GeneralName bcGeneralName) { + ASN1Encodable name = bcGeneralName.getName(); + switch (bcGeneralName.getTagNo()) { + case GeneralName.rfc822Name: + case GeneralName.dNSName: + case GeneralName.uniformResourceIdentifier: + return DERIA5String.getInstance(name).getString(); + case GeneralName.directoryName: + return X500Name.getInstance(name).toString(); + default: + return name.toString(); + } + } + + @Override + public String toString() { + return "SubjectAlternativeName{" + + "type=" + type + + ", value='" + value + '\'' + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SubjectAlternativeName that = (SubjectAlternativeName) o; + return type == that.type && + Objects.equals(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(type, value); + } + + public enum Type { + OTHER_NAME(0), + RFC822_NAME(1), + DNS_NAME(2), + X400_ADDRESS(3), + DIRECTORY_NAME(4), + EDI_PARITY_NAME(5), + UNIFORM_RESOURCE_IDENTIFIER(6), + IP_ADDRESS(7), + REGISTERED_ID(8); + + final int tag; + + Type(int tag) { + this.tag = tag; + } + + public static Type fromTag(int tag) { + return Arrays.stream(Type.values()) + .filter(type -> type.tag == tag) + .findAny() + .orElseThrow(() -> new IllegalArgumentException("Invalid tag: " + tag)); + } + + public int getTag() { + return tag; + } + } +} diff --git a/vespajlib/src/main/java/com/yahoo/security/X509CertificateBuilder.java b/vespajlib/src/main/java/com/yahoo/security/X509CertificateBuilder.java new file mode 100644 index 00000000000..54d7d39253e --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/X509CertificateBuilder.java @@ -0,0 +1,167 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.bouncycastle.asn1.x509.BasicConstraints; +import org.bouncycastle.asn1.x509.Extension; +import org.bouncycastle.asn1.x509.GeneralName; +import org.bouncycastle.asn1.x509.GeneralNames; +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; +import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder; +import org.bouncycastle.operator.ContentSigner; +import org.bouncycastle.operator.OperatorException; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; +import org.bouncycastle.pkcs.PKCS10CertificationRequest; +import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequest; + +import javax.security.auth.x500.X500Principal; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.math.BigInteger; +import java.security.GeneralSecurityException; +import java.security.KeyPair; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.SecureRandom; +import java.security.cert.X509Certificate; +import java.sql.Date; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; + +import static com.yahoo.security.SubjectAlternativeName.Type.DNS_NAME; + + +/** + * @author bjorncs + */ +public class X509CertificateBuilder { + + private final BigInteger serialNumber; + private final SignatureAlgorithm signingAlgorithm; + private final PrivateKey caPrivateKey; + private final Instant notBefore; + private final Instant notAfter; + private final List<SubjectAlternativeName> subjectAlternativeNames = new ArrayList<>(); + private final X500Principal issuer; + private final X500Principal subject; + private final PublicKey certPublicKey; + private BasicConstraintsExtension basicConstraintsExtension; + + private X509CertificateBuilder(X500Principal issuer, + X500Principal subject, + Instant notBefore, + Instant notAfter, + PublicKey certPublicKey, + PrivateKey caPrivateKey, + SignatureAlgorithm signingAlgorithm, + BigInteger serialNumber) { + this.issuer = issuer; + this.subject = subject; + this.notBefore = notBefore; + this.notAfter = notAfter; + this.certPublicKey = certPublicKey; + this.caPrivateKey = caPrivateKey; + this.signingAlgorithm = signingAlgorithm; + this.serialNumber = serialNumber; + } + + public static X509CertificateBuilder fromCsr(Pkcs10Csr csr, + X500Principal caIssuer, + Instant notBefore, + Instant notAfter, + PrivateKey caPrivateKey, + SignatureAlgorithm signingAlgorithm, + BigInteger serialNumber) { + try { + PKCS10CertificationRequest bcCsr = csr.getBcCsr(); + PublicKey publicKey = new JcaPKCS10CertificationRequest(bcCsr) + .setProvider(BouncyCastleProviderHolder.getInstance()) + .getPublicKey(); + return new X509CertificateBuilder(caIssuer, + new X500Principal(bcCsr.getSubject().getEncoded()), + notBefore, + notAfter, + publicKey, + caPrivateKey, + signingAlgorithm, + serialNumber); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public static X509CertificateBuilder fromKeypair(KeyPair keyPair, + X500Principal subject, + Instant notBefore, + Instant notAfter, + SignatureAlgorithm signingAlgorithm, + BigInteger serialNumber) { + return new X509CertificateBuilder(subject, + subject, + notBefore, + notAfter, + keyPair.getPublic(), + keyPair.getPrivate(), + signingAlgorithm, + serialNumber); + } + + /** + * @return generates a cryptographically secure positive serial number up to 128 bits + */ + public static BigInteger generateRandomSerialNumber() { + return new BigInteger(128, new SecureRandom()); + } + + public X509CertificateBuilder addSubjectAlternativeName(String dnsName) { + this.subjectAlternativeNames.add(new SubjectAlternativeName(DNS_NAME, dnsName)); + return this; + } + + public X509CertificateBuilder addSubjectAlternativeName(SubjectAlternativeName san) { + this.subjectAlternativeNames.add(san); + return this; + } + + public X509CertificateBuilder setBasicConstraints(boolean isCritical, boolean isCertAuthorityCertificate) { + this.basicConstraintsExtension = new BasicConstraintsExtension(isCritical, isCertAuthorityCertificate); + return this; + } + + public X509CertificateBuilder setIsCertAuthority(boolean isCertAuthority) { + return setBasicConstraints(true, isCertAuthority); + } + + public X509Certificate build() { + try { + JcaX509v3CertificateBuilder jcaCertBuilder = new JcaX509v3CertificateBuilder( + issuer, serialNumber, Date.from(notBefore), Date.from(notAfter), subject, certPublicKey); + if (basicConstraintsExtension != null) { + jcaCertBuilder.addExtension( + Extension.basicConstraints, + basicConstraintsExtension.isCritical, + new BasicConstraints(basicConstraintsExtension.isCertAuthorityCertificate)); + } + if (!subjectAlternativeNames.isEmpty()) { + GeneralNames generalNames = new GeneralNames( + subjectAlternativeNames.stream() + .map(SubjectAlternativeName::toGeneralName) + .toArray(GeneralName[]::new)); + jcaCertBuilder.addExtension(Extension.subjectAlternativeName, false, generalNames); + } + ContentSigner contentSigner = new JcaContentSignerBuilder(signingAlgorithm.getAlgorithmName()) + .setProvider(BouncyCastleProviderHolder.getInstance()) + .build(caPrivateKey); + return new JcaX509CertificateConverter() + .setProvider(BouncyCastleProviderHolder.getInstance()) + .getCertificate(jcaCertBuilder.build(contentSigner)); + } catch (OperatorException | GeneralSecurityException e) { + throw new RuntimeException(e); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/security/X509CertificateUtils.java b/vespajlib/src/main/java/com/yahoo/security/X509CertificateUtils.java new file mode 100644 index 00000000000..33bd750bac5 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/X509CertificateUtils.java @@ -0,0 +1,136 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.bouncycastle.asn1.ASN1Encodable; +import org.bouncycastle.asn1.ASN1OctetString; +import org.bouncycastle.asn1.ASN1Primitive; +import org.bouncycastle.asn1.x509.GeneralNames; +import org.bouncycastle.cert.X509CertificateHolder; +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; +import org.bouncycastle.openssl.PEMParser; +import org.bouncycastle.openssl.jcajce.JcaPEMWriter; +import org.bouncycastle.util.io.pem.PemObject; + +import javax.naming.NamingException; +import javax.naming.ldap.LdapName; +import javax.security.auth.x500.X500Principal; +import java.io.IOException; +import java.io.StringReader; +import java.io.StringWriter; +import java.io.UncheckedIOException; +import java.security.GeneralSecurityException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static com.yahoo.security.Extension.SUBJECT_ALTERNATIVE_NAMES; +import static java.util.stream.Collectors.toList; + +/** + * @author bjorncs + */ +public class X509CertificateUtils { + + private X509CertificateUtils() {} + + public static X509Certificate fromPem(String pem) { + try (PEMParser parser = new PEMParser(new StringReader(pem))) { + return toX509Certificate(parser.readObject()); + } catch (IOException e) { + throw new UncheckedIOException(e); + } catch (CertificateException e) { + throw new RuntimeException(e); + } + } + + public static List<X509Certificate> certificateListFromPem(String pem) { + try (PEMParser parser = new PEMParser(new StringReader(pem))) { + List<X509Certificate> list = new ArrayList<>(); + Object pemObject; + while ((pemObject = parser.readObject()) != null) { + list.add(toX509Certificate(pemObject)); + } + return list; + } catch (IOException e) { + throw new UncheckedIOException(e); + } catch (CertificateException e) { + throw new RuntimeException(e); + } + } + + private static X509Certificate toX509Certificate(Object pemObject) throws CertificateException { + if (pemObject instanceof X509Certificate) { + return (X509Certificate) pemObject; + } + if (pemObject instanceof X509CertificateHolder) { + return new JcaX509CertificateConverter() + .setProvider(BouncyCastleProviderHolder.getInstance()) + .getCertificate((X509CertificateHolder) pemObject); + } + throw new IllegalArgumentException("Invalid type of PEM object: " + pemObject); + } + + public static String toPem(X509Certificate certificate) { + try (StringWriter stringWriter = new StringWriter(); JcaPEMWriter pemWriter = new JcaPEMWriter(stringWriter)) { + pemWriter.writeObject(new PemObject("CERTIFICATE", certificate.getEncoded())); + pemWriter.flush(); + return stringWriter.toString(); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public static String toPem(List<X509Certificate> certificates) { + try (StringWriter stringWriter = new StringWriter(); JcaPEMWriter pemWriter = new JcaPEMWriter(stringWriter)) { + for (X509Certificate certificate : certificates) { + pemWriter.writeObject(new PemObject("CERTIFICATE", certificate.getEncoded())); + } + pemWriter.flush(); + return stringWriter.toString(); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public static List<String> getSubjectCommonNames(X509Certificate certificate) { + return getCommonNames(certificate.getSubjectX500Principal()); + } + + public static List<String> getIssuerCommonNames(X509Certificate certificate) { + return getCommonNames(certificate.getIssuerX500Principal()); + } + + public static List<String> getCommonNames(X500Principal subject) { + try { + String subjectPrincipal = subject.getName(); + return new LdapName(subjectPrincipal).getRdns().stream() + .filter(rdn -> rdn.getType().equalsIgnoreCase("cn")) + .map(rdn -> rdn.getValue().toString()) + .collect(toList()); + } catch (NamingException e) { + throw new IllegalArgumentException("Invalid CN: " + e, e); + } + + } + + public static List<SubjectAlternativeName> getSubjectAlternativeNames(X509Certificate certificate) { + try { + byte[] extensionValue = certificate.getExtensionValue(SUBJECT_ALTERNATIVE_NAMES.getOId()); + if (extensionValue == null) return Collections.emptyList(); + ASN1Encodable asn1Encodable = ASN1Primitive.fromByteArray(extensionValue); + if (asn1Encodable instanceof ASN1OctetString) { + asn1Encodable = ASN1Primitive.fromByteArray(((ASN1OctetString) asn1Encodable).getOctets()); + } + GeneralNames names = GeneralNames.getInstance(asn1Encodable); + return SubjectAlternativeName.fromGeneralNames(names); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } +} diff --git a/vespajlib/src/main/java/com/yahoo/security/package-info.java b/vespajlib/src/main/java/com/yahoo/security/package-info.java new file mode 100644 index 00000000000..10a4c9c0e0d --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/package-info.java @@ -0,0 +1,9 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +/** + * @author bjorncs + */ + +@ExportPackage +package com.yahoo.security; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/vespajlib/src/main/java/com/yahoo/security/tls/TransportSecurityOptions.java b/vespajlib/src/main/java/com/yahoo/security/tls/TransportSecurityOptions.java new file mode 100644 index 00000000000..f0d1edd6889 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/tls/TransportSecurityOptions.java @@ -0,0 +1,90 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Objects; +import java.util.Optional; + +/** + * Generic TLS configuration for Vespa + * + * @author bjorncs + */ +public class TransportSecurityOptions { + + private static final ObjectMapper mapper = new ObjectMapper(); + + private final Path privateKeyFile; + private final Path certificatesFile; + private final Path caCertificatesFile; + + public TransportSecurityOptions(String privateKeyFile, String certificatesFile, String caCertificatesFile) { + this(Paths.get(privateKeyFile), Paths.get(certificatesFile), Paths.get(caCertificatesFile)); + } + + public TransportSecurityOptions(Path privateKeyFile, Path certificatesFile, Path caCertificatesFile) { + this.privateKeyFile = privateKeyFile; + this.certificatesFile = certificatesFile; + this.caCertificatesFile = caCertificatesFile; + } + + public Path getPrivateKeyFile() { + return privateKeyFile; + } + + public Path getCertificatesFile() { + return certificatesFile; + } + + public Path getCaCertificatesFile() { + return caCertificatesFile; + } + + public static TransportSecurityOptions fromJsonFile(Path file) { + try { + JsonNode root = mapper.readTree(file.toFile()); + JsonNode filesNode = getField(root, "files"); + String privateKeyFile = getField(filesNode, "private-key").asText(); + String certificatesFile = getField(filesNode, "certificates").asText(); + String caCertificatesFile = getField(filesNode, "ca-certificates").asText(); + return new TransportSecurityOptions(privateKeyFile, certificatesFile, caCertificatesFile); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private static JsonNode getField(JsonNode root, String fieldName) { + return Optional.ofNullable(root.get(fieldName)) + .orElseThrow(() -> new IllegalArgumentException(String.format("'%s' field missing", fieldName))); + } + + @Override + public String toString() { + return "TransportSecurityOptions{" + + "privateKeyFile=" + privateKeyFile + + ", certificatesFile=" + certificatesFile + + ", caCertificatesFile=" + caCertificatesFile + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TransportSecurityOptions that = (TransportSecurityOptions) o; + return Objects.equals(privateKeyFile, that.privateKeyFile) && + Objects.equals(certificatesFile, that.certificatesFile) && + Objects.equals(caCertificatesFile, that.caCertificatesFile); + } + + @Override + public int hashCode() { + return Objects.hash(privateKeyFile, certificatesFile, caCertificatesFile); + } +}
\ No newline at end of file diff --git a/vespajlib/src/main/java/com/yahoo/security/tls/package-info.java b/vespajlib/src/main/java/com/yahoo/security/tls/package-info.java new file mode 100644 index 00000000000..b5668182f14 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/security/tls/package-info.java @@ -0,0 +1,8 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +/** + * @author bjorncs + */ +@ExportPackage +package com.yahoo.security.tls; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/vespajlib/src/test/java/com/yahoo/security/KeyStoreBuilderTest.java b/vespajlib/src/test/java/com/yahoo/security/KeyStoreBuilderTest.java new file mode 100644 index 00000000000..06ea5d963a3 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/KeyStoreBuilderTest.java @@ -0,0 +1,55 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.nio.file.Path; +import java.security.KeyPair; +import java.security.cert.X509Certificate; + +import static com.yahoo.security.TestUtils.createCertificate; +import static com.yahoo.security.TestUtils.createKeystoreFile; + + +/** + * @author bjorncs + */ +public class KeyStoreBuilderTest { + + private static final char[] PASSWORD = new char[0]; + + @Rule + public TemporaryFolder tempDirectory = new TemporaryFolder(); + + @Test + public void can_create_jks_keystore_from_privatekey_and_certificate() throws Exception { + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + X509Certificate certificate = createCertificate(keyPair); + KeyStoreBuilder.withType(KeyStoreType.JKS) + .withKeyEntry("key", keyPair.getPrivate(), certificate) + .build(); + } + + @Test + public void can_build_jks_keystore_from_file() throws Exception { + Path keystoreFile = tempDirectory.newFile().toPath(); + createKeystoreFile(keystoreFile, KeyStoreType.JKS, PASSWORD); + + KeyStoreBuilder.withType(KeyStoreType.JKS) + .fromFile(keystoreFile, PASSWORD) + .build(); + } + + @Test + public void can_build_pcks12_keystore_from_file() throws Exception { + Path keystoreFile = tempDirectory.newFile().toPath(); + createKeystoreFile(keystoreFile, KeyStoreType.PKCS12, PASSWORD); + + KeyStoreBuilder.withType(KeyStoreType.PKCS12) + .fromFile(keystoreFile, PASSWORD) + .build(); + } + +}
\ No newline at end of file diff --git a/vespajlib/src/test/java/com/yahoo/security/KeyUtilsTest.java b/vespajlib/src/test/java/com/yahoo/security/KeyUtilsTest.java new file mode 100644 index 00000000000..825f4446d94 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/KeyUtilsTest.java @@ -0,0 +1,44 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.junit.Test; + +import java.security.KeyPair; +import java.security.PrivateKey; +import java.security.PublicKey; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; + +/** + * @author bjorncs + */ +public class KeyUtilsTest { + + @Test + public void can_extract_public_key_from_rsa_private() { + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); + PublicKey publicKey = KeyUtils.extractPublicKey(keyPair.getPrivate()); + assertNotNull(publicKey); + } + + @Test + public void can_extract_public_key_from_ecdsa_private() { + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC); + PublicKey publicKey = KeyUtils.extractPublicKey(keyPair.getPrivate()); + assertNotNull(publicKey); + } + + @Test + public void can_serialize_deserialize_pem() { + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); + String pem = KeyUtils.toPem(keyPair.getPrivate()); + assertThat(pem, containsString("BEGIN RSA PRIVATE KEY")); + assertThat(pem, containsString("END RSA PRIVATE KEY")); + PrivateKey deserializedKey = KeyUtils.fromPemEncodedPrivateKey(pem); + assertEquals(keyPair.getPrivate(), deserializedKey); + } + +}
\ No newline at end of file diff --git a/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrBuilderTest.java b/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrBuilderTest.java new file mode 100644 index 00000000000..d51203a5cb2 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrBuilderTest.java @@ -0,0 +1,27 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.junit.Test; + +import javax.security.auth.x500.X500Principal; +import java.security.KeyPair; + +import static org.junit.Assert.assertEquals; + +/** + * @author bjorncs + */ +public class Pkcs10CsrBuilderTest { + + @Test + public void can_build_csr_with_sans() { + X500Principal subject = new X500Principal("CN=subject"); + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + Pkcs10Csr csr = Pkcs10CsrBuilder.fromKeypair(subject, keypair, SignatureAlgorithm.SHA512_WITH_ECDSA) + .addSubjectAlternativeName("san1.com") + .addSubjectAlternativeName("san2.com") + .build(); + assertEquals(subject, csr.getSubject()); + } + +}
\ No newline at end of file diff --git a/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrTest.java b/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrTest.java new file mode 100644 index 00000000000..cc1f6cc6a14 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrTest.java @@ -0,0 +1,57 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.junit.Test; + +import javax.security.auth.x500.X500Principal; +import java.security.KeyPair; +import java.util.Arrays; +import java.util.List; + +import static com.yahoo.security.SubjectAlternativeName.Type.DNS_NAME; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * @author bjorncs + */ +public class Pkcs10CsrTest { + + @Test + public void can_read_subject_alternative_names() { + X500Principal subject = new X500Principal("CN=subject"); + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + SubjectAlternativeName san1 = new SubjectAlternativeName(DNS_NAME, "san1.com"); + SubjectAlternativeName san2 = new SubjectAlternativeName(DNS_NAME, "san2.com"); + Pkcs10Csr csr = Pkcs10CsrBuilder.fromKeypair(subject, keypair, SignatureAlgorithm.SHA512_WITH_ECDSA) + .addSubjectAlternativeName(san1) + .addSubjectAlternativeName(san2) + .build(); + assertEquals(Arrays.asList(san1, san2), csr.getSubjectAlternativeNames()); + } + + @Test + public void can_read_basic_constraints() { + X500Principal subject = new X500Principal("CN=subject"); + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + Pkcs10Csr csr = Pkcs10CsrBuilder.fromKeypair(subject, keypair, SignatureAlgorithm.SHA512_WITH_ECDSA) + .setBasicConstraints(true, true) + .build(); + assertTrue(csr.getBasicConstraints().isPresent()); + assertTrue(csr.getBasicConstraints().get()); + } + + @Test + public void can_read_extensions() { + X500Principal subject = new X500Principal("CN=subject"); + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + Pkcs10Csr csr = Pkcs10CsrBuilder.fromKeypair(subject, keypair, SignatureAlgorithm.SHA512_WITH_ECDSA) + .addSubjectAlternativeName("san") + .setBasicConstraints(true, true) + .build(); + List<String> expected = Arrays.asList(Extension.BASIC_CONSTRAINTS.getOId(), Extension.SUBJECT_ALTERNATIVE_NAMES.getOId()); + List<String> actual = csr.getExtensionOIds(); + assertEquals(expected, actual); + } + +}
\ No newline at end of file diff --git a/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrUtilsTest.java b/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrUtilsTest.java new file mode 100644 index 00000000000..04d35a537bb --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/Pkcs10CsrUtilsTest.java @@ -0,0 +1,30 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.junit.Test; + +import javax.security.auth.x500.X500Principal; +import java.security.KeyPair; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +/** + * @author bjorncs + */ +public class Pkcs10CsrUtilsTest { + + @Test + public void can_deserialize_serialized_pem_csr() { + X500Principal subject = new X500Principal("CN=subject"); + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + Pkcs10Csr csr = Pkcs10CsrBuilder.fromKeypair(subject, keypair, SignatureAlgorithm.SHA512_WITH_ECDSA).build(); + String pem = Pkcs10CsrUtils.toPem(csr); + Pkcs10Csr deserializedCsr = Pkcs10CsrUtils.fromPem(pem); + assertThat(pem, containsString("BEGIN CERTIFICATE REQUEST")); + assertThat(pem, containsString("END CERTIFICATE REQUEST")); + assertEquals(subject, deserializedCsr.getSubject()); + } + +}
\ No newline at end of file diff --git a/vespajlib/src/test/java/com/yahoo/security/SslContextBuilderTest.java b/vespajlib/src/test/java/com/yahoo/security/SslContextBuilderTest.java new file mode 100644 index 00000000000..cc269a4ef43 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/SslContextBuilderTest.java @@ -0,0 +1,77 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.nio.file.Path; +import java.security.KeyPair; +import java.security.cert.X509Certificate; + +import static com.yahoo.security.TestUtils.createCertificate; +import static com.yahoo.security.TestUtils.createKeystore; +import static com.yahoo.security.TestUtils.createKeystoreFile; + +/** + * @author bjorncs + */ +public class SslContextBuilderTest { + + private static final char[] PASSWORD = new char[0]; + + @Rule + public TemporaryFolder tempDirectory = new TemporaryFolder(); + + @Test + public void can_build_sslcontext_with_truststore_only() throws Exception { + new SslContextBuilder() + .withTrustStore(createKeystore(KeyStoreType.JKS, PASSWORD)) + .build(); + } + + @Test + public void can_build_sslcontext_with_keystore_only() throws Exception { + new SslContextBuilder() + .withKeyStore(createKeystore(KeyStoreType.JKS, PASSWORD), PASSWORD) + .build(); + } + + @Test + public void can_build_sslcontext_with_truststore_and_keystore() throws Exception { + new SslContextBuilder() + .withKeyStore(createKeystore(KeyStoreType.JKS, PASSWORD), PASSWORD) + .withTrustStore(createKeystore(KeyStoreType.JKS, PASSWORD)) + .build(); + } + + @Test + public void can_build_sslcontext_with_keystore_from_private_key_and_certificate() throws Exception { + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + X509Certificate certificate = createCertificate(keyPair); + new SslContextBuilder() + .withKeyStore(keyPair.getPrivate(), certificate) + .build(); + } + + @Test + public void can_build_sslcontext_with_jks_keystore_from_file() throws Exception { + Path keystoreFile = tempDirectory.newFile().toPath(); + createKeystoreFile(keystoreFile, KeyStoreType.JKS, PASSWORD); + + new SslContextBuilder() + .withKeyStore(keystoreFile, PASSWORD, KeyStoreType.JKS) + .build(); + } + + @Test + public void can_build_sslcontext_with_pcks12_keystore_from_file() throws Exception { + Path keystoreFile = tempDirectory.newFile().toPath(); + createKeystoreFile(keystoreFile, KeyStoreType.PKCS12, PASSWORD); + + new SslContextBuilder() + .withKeyStore(keystoreFile, PASSWORD, KeyStoreType.PKCS12) + .build(); + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/security/TestUtils.java b/vespajlib/src/test/java/com/yahoo/security/TestUtils.java new file mode 100644 index 00000000000..fcfcfb2b761 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/TestUtils.java @@ -0,0 +1,42 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import javax.security.auth.x500.X500Principal; +import java.math.BigInteger; +import java.nio.file.Path; +import java.security.KeyPair; +import java.security.KeyStore; +import java.security.cert.X509Certificate; +import java.time.Instant; +import java.time.temporal.ChronoUnit; + +import static com.yahoo.security.KeyStoreUtils.writeKeyStoreToFile; + + +/** + * @author bjorncs + */ +class TestUtils { + + static KeyStore createKeystore(KeyStoreType type, char[] password) { + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + return KeyStoreBuilder.withType(type) + .withKeyEntry("entry-name", keyPair.getPrivate(), password, createCertificate(keyPair)) + .build(); + } + + static X509Certificate createCertificate(KeyPair keyPair) { + return createCertificate(keyPair, new X500Principal("CN=mysubject")); + } + + static X509Certificate createCertificate(KeyPair keyPair, X500Principal subject) { + return X509CertificateBuilder + .fromKeypair( + keyPair, subject, Instant.now(), Instant.now().plus(1, ChronoUnit.DAYS), SignatureAlgorithm.SHA512_WITH_ECDSA, BigInteger.valueOf(1)) + .build(); + } + + static void createKeystoreFile(Path file, KeyStoreType type, char[] password) { + writeKeyStoreToFile(createKeystore(type, password), file, password); + } +} diff --git a/vespajlib/src/test/java/com/yahoo/security/X509CertificateBuilderTest.java b/vespajlib/src/test/java/com/yahoo/security/X509CertificateBuilderTest.java new file mode 100644 index 00000000000..7e6d343b570 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/X509CertificateBuilderTest.java @@ -0,0 +1,83 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import javax.security.auth.x500.X500Principal; +import java.math.BigInteger; +import java.security.KeyPair; +import java.security.cert.X509Certificate; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collection; + +import static org.junit.Assert.assertEquals; + +/** + * @author bjorncs + */ +@RunWith(Parameterized.class) +public class X509CertificateBuilderTest { + + @Parameterized.Parameters(name = "{0}") + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {KeyAlgorithm.RSA, 2048, SignatureAlgorithm.SHA512_WITH_RSA}, + {KeyAlgorithm.EC, 256, SignatureAlgorithm.SHA512_WITH_ECDSA}}); + } + + private final KeyAlgorithm keyAlgorithm; + private final int keySize; + private final SignatureAlgorithm signatureAlgorithm; + + public X509CertificateBuilderTest(KeyAlgorithm keyAlgorithm, + int keySize, + SignatureAlgorithm signatureAlgorithm) { + this.keyAlgorithm = keyAlgorithm; + this.keySize = keySize; + this.signatureAlgorithm = signatureAlgorithm; + } + + @Test + public void can_build_self_signed_certificate() { + KeyPair keyPair = KeyUtils.generateKeypair(keyAlgorithm, keySize); + X500Principal subject = new X500Principal("CN=myservice"); + X509Certificate cert = + X509CertificateBuilder.fromKeypair( + keyPair, + subject, + Instant.now(), + Instant.now().plus(1, ChronoUnit.DAYS), + signatureAlgorithm, + BigInteger.valueOf(1)) + .setBasicConstraints(true, true) + .build(); + assertEquals(subject, cert.getSubjectX500Principal()); + } + + @Test + public void can_build_certificate_from_csr() { + X500Principal subject = new X500Principal("CN=subject"); + X500Principal issuer = new X500Principal("CN=issuer"); + KeyPair csrKeypair = KeyUtils.generateKeypair(keyAlgorithm, keySize); + Pkcs10Csr csr = Pkcs10CsrBuilder.fromKeypair(subject, csrKeypair, signatureAlgorithm).build(); + KeyPair caKeypair = KeyUtils.generateKeypair(keyAlgorithm, keySize); + X509Certificate cert = X509CertificateBuilder + .fromCsr( + csr, + issuer, + Instant.now(), + Instant.now().plus(1, ChronoUnit.DAYS), + caKeypair.getPrivate(), + signatureAlgorithm, + BigInteger.valueOf(1)) + .addSubjectAlternativeName("subject1.alt") + .addSubjectAlternativeName("subject2.alt") + .build(); + assertEquals(subject, cert.getSubjectX500Principal()); + } + +}
\ No newline at end of file diff --git a/vespajlib/src/test/java/com/yahoo/security/X509CertificateUtilsTest.java b/vespajlib/src/test/java/com/yahoo/security/X509CertificateUtilsTest.java new file mode 100644 index 00000000000..76a93028efe --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/X509CertificateUtilsTest.java @@ -0,0 +1,74 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security; + +import org.junit.Test; + +import javax.security.auth.x500.X500Principal; +import java.math.BigInteger; +import java.security.KeyPair; +import java.security.cert.X509Certificate; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.List; + +import static com.yahoo.security.SubjectAlternativeName.Type.DNS_NAME; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +/** + * @author bjorncs + */ +public class X509CertificateUtilsTest { + @Test + public void can_deserialize_serialized_pem_certificate() { + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + X500Principal subject = new X500Principal("CN=myservice"); + X509Certificate cert = TestUtils.createCertificate(keypair, subject); + assertEquals(subject, cert.getSubjectX500Principal()); + String pem = X509CertificateUtils.toPem(cert); + assertThat(pem, containsString("BEGIN CERTIFICATE")); + assertThat(pem, containsString("END CERTIFICATE")); + X509Certificate deserializedCert = X509CertificateUtils.fromPem(pem); + assertEquals(subject, deserializedCert.getSubjectX500Principal()); + } + + @Test + public void can_deserialize_serialized_pem_certificate_list() { + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + X500Principal subject1 = new X500Principal("CN=myservice1"); + X509Certificate cert1 = TestUtils.createCertificate(keypair, subject1); + X500Principal subject2 = new X500Principal("CN=myservice2"); + X509Certificate cert2 = TestUtils.createCertificate(keypair, subject2); + List<X509Certificate> certificateList = Arrays.asList(cert1, cert2); + String pem = X509CertificateUtils.toPem(certificateList); + List<X509Certificate> deserializedCertificateList = X509CertificateUtils.certificateListFromPem(pem); + assertEquals(2, certificateList.size()); + assertEquals(subject1, deserializedCertificateList.get(0).getSubjectX500Principal()); + assertEquals(subject2, deserializedCertificateList.get(1).getSubjectX500Principal()); + } + + @Test + public void can_list_subject_alternative_names() { + KeyPair keypair = KeyUtils.generateKeypair(KeyAlgorithm.EC, 256); + X500Principal subject = new X500Principal("CN=myservice"); + SubjectAlternativeName san = new SubjectAlternativeName(DNS_NAME, "dns-san"); + X509Certificate cert = X509CertificateBuilder + .fromKeypair( + keypair, + subject, + Instant.now(), + Instant.now().plus(1, ChronoUnit.DAYS), + SignatureAlgorithm.SHA512_WITH_ECDSA, + BigInteger.valueOf(1)) + .addSubjectAlternativeName(san) + .build(); + + List<SubjectAlternativeName> sans = X509CertificateUtils.getSubjectAlternativeNames(cert); + assertThat(sans.size(), is(1)); + assertThat(sans.get(0), equalTo(san)); + } +}
\ No newline at end of file diff --git a/vespajlib/src/test/java/com/yahoo/security/tls/TransportSecurityOptionsTest.java b/vespajlib/src/test/java/com/yahoo/security/tls/TransportSecurityOptionsTest.java new file mode 100644 index 00000000000..ad80c52ae2a --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/security/tls/TransportSecurityOptionsTest.java @@ -0,0 +1,24 @@ +package com.yahoo.security.tls; + +import org.junit.Test; + +import java.nio.file.Path; +import java.nio.file.Paths; + +import static org.junit.Assert.*; + +/** + * @author bjorncs + */ +public class TransportSecurityOptionsTest { + + private static final Path TEST_CONFIG_FILE = Paths.get("src/test/resources/transport-security-options.json"); + + @Test + public void can_read_options_from_json_file() { + TransportSecurityOptions expectedOptions = new TransportSecurityOptions("myhost.key", "certs.pem", "my_cas.pem"); + TransportSecurityOptions actualOptions = TransportSecurityOptions.fromJsonFile(TEST_CONFIG_FILE); + assertEquals(expectedOptions, actualOptions); + } + +}
\ No newline at end of file diff --git a/vespajlib/src/test/resources/transport-security-options.json b/vespajlib/src/test/resources/transport-security-options.json new file mode 100644 index 00000000000..0506c130722 --- /dev/null +++ b/vespajlib/src/test/resources/transport-security-options.json @@ -0,0 +1,7 @@ +{ + "files": { + "private-key": "myhost.key", + "ca-certificates": "my_cas.pem", + "certificates": "certs.pem" + } +}
\ No newline at end of file diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt index 33553da9422..4ae98be29b6 100644 --- a/vespalib/CMakeLists.txt +++ b/vespalib/CMakeLists.txt @@ -56,6 +56,8 @@ vespa_define_module( src/tests/net/send_fd src/tests/net/socket src/tests/net/socket_spec + src/tests/net/tls/openssl_impl + src/tests/net/tls/transport_options src/tests/objects/nbostream src/tests/optimized src/tests/printable @@ -118,6 +120,8 @@ vespa_define_module( src/vespa/vespalib/io src/vespa/vespalib/locale src/vespa/vespalib/net + src/vespa/vespalib/net/tls + src/vespa/vespalib/net/tls/impl src/vespa/vespalib/objects src/vespa/vespalib/stllike src/vespa/vespalib/test diff --git a/vespalib/src/tests/net/tls/openssl_impl/CMakeLists.txt b/vespalib/src/tests/net/tls/openssl_impl/CMakeLists.txt new file mode 100644 index 00000000000..799e2291d7c --- /dev/null +++ b/vespalib/src/tests/net/tls/openssl_impl/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_net_tls_openssl_impl_test_app TEST + SOURCES + openssl_impl_test.cpp + DEPENDS + vespalib +) +vespa_add_test(NAME vespalib_net_tls_openssl_impl_test_app COMMAND vespalib_net_tls_openssl_impl_test_app) + diff --git a/vespalib/src/tests/net/tls/openssl_impl/openssl_impl_test.cpp b/vespalib/src/tests/net/tls/openssl_impl/openssl_impl_test.cpp new file mode 100644 index 00000000000..4e8bf31e75e --- /dev/null +++ b/vespalib/src/tests/net/tls/openssl_impl/openssl_impl_test.cpp @@ -0,0 +1,134 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include <vespa/vespalib/testkit/test_kit.h> +#include <vespa/vespalib/net/tls/tls_context.h> +#include <vespa/vespalib/net/tls/transport_security_options.h> +#include <vespa/vespalib/net/tls/crypto_codec.h> +#include <vespa/vespalib/test/make_tls_options_for_testing.h> +#include <iostream> +#include <stdlib.h> + +using namespace vespalib; +using namespace vespalib::net::tls; + +const char* decode_state_to_str(DecodeResult::State state) noexcept { + switch (state) { + case DecodeResult::State::Failed: return "Broken"; + case DecodeResult::State::OK: return "OK"; + case DecodeResult::State::NeedsMorePeerData: return "NeedsMorePeerData"; + default: + abort(); + } +} + +const char* hs_state_to_str(HandshakeResult::State state) noexcept { + switch (state) { + case HandshakeResult::State::Failed: return "Broken"; + case HandshakeResult::State::Done: return "Done"; + case HandshakeResult::State::NeedsMorePeerData: return "NeedsMorePeerData"; + default: + abort(); + } +} + +void log_handshake_result(const char* mode, const HandshakeResult& res) { + fprintf(stderr, "(handshake) %s consumed %zu peer bytes, wrote %zu peer bytes. State: %s\n", + mode, res.bytes_consumed, res.bytes_produced, + hs_state_to_str(res.state)); +} + +void log_encode_result(const char* mode, const EncodeResult& res) { + fprintf(stderr, "(encode) %s read %zu plaintext, wrote %zu cipher. State: %s\n", + mode, res.bytes_consumed, res.bytes_produced, + res.failed ? "Broken! D:" : "OK"); +} + +void log_decode_result(const char* mode, const DecodeResult& res) { + fprintf(stderr, "(decode) %s read %zu cipher, wrote %zu plaintext. State: %s\n", + mode, res.bytes_consumed, res.bytes_produced, + decode_state_to_str(res.state)); +} + +bool complete_handshake(CryptoCodec& client, CryptoCodec& server) { + // Not using vespalib::string here since it doesn't have erase(iter, length) implemented. + std::string client_to_server_buf; + std::string server_to_client_buf; + + HandshakeResult cli_res; + HandshakeResult serv_res; + while (!(cli_res.done() && serv_res.done())) { + client_to_server_buf.resize(client.min_encode_buffer_size()); + server_to_client_buf.resize(server.min_encode_buffer_size()); + + cli_res = client.handshake(server_to_client_buf.data(), serv_res.bytes_produced, + client_to_server_buf.data(), client_to_server_buf.size()); + log_handshake_result("client", cli_res); + server_to_client_buf.erase(server_to_client_buf.begin(), server_to_client_buf.begin() + cli_res.bytes_consumed); + + serv_res = server.handshake(client_to_server_buf.data(), cli_res.bytes_produced, + server_to_client_buf.data(), server_to_client_buf.size()); + log_handshake_result("server", serv_res); + client_to_server_buf.erase(client_to_server_buf.begin(), client_to_server_buf.begin() + serv_res.bytes_consumed); + + if (cli_res.failed() || serv_res.failed()) { + return false; + } + } + return true; +} + +TEST("client and server can complete handshake") { + // TODO move to fixture + auto tls_opts = vespalib::test::make_tls_options_for_testing(); + auto tls_ctx = TlsContext::create_default_context(tls_opts); + auto client = CryptoCodec::create_default_codec(*tls_ctx, CryptoCodec::Mode::Client); + auto server = CryptoCodec::create_default_codec(*tls_ctx, CryptoCodec::Mode::Server); + + EXPECT_TRUE(complete_handshake(*client, *server)); +} + +TEST("client can send single data frame to server after handshake") { + // TODO move to fixture + auto tls_opts = vespalib::test::make_tls_options_for_testing(); + auto tls_ctx = TlsContext::create_default_context(tls_opts); + auto client = CryptoCodec::create_default_codec(*tls_ctx, CryptoCodec::Mode::Client); + auto server = CryptoCodec::create_default_codec(*tls_ctx, CryptoCodec::Mode::Server); + + ASSERT_TRUE(complete_handshake(*client, *server)); + + std::string client_to_server_buf; + client_to_server_buf.resize(client->min_encode_buffer_size()); + + std::string client_plaintext = "Hellooo world! :D"; + auto cli_res = client->encode(client_plaintext.data(), client_plaintext.size(), + client_to_server_buf.data(), client_to_server_buf.size()); + log_encode_result("client", cli_res); + + std::string server_plaintext_out; + server_plaintext_out.resize(server->min_decode_buffer_size()); + auto serv_res = server->decode(client_to_server_buf.data(), cli_res.bytes_produced, + server_plaintext_out.data(), server_plaintext_out.size()); + log_decode_result("server", serv_res); + + ASSERT_FALSE(cli_res.failed); + ASSERT_FALSE(serv_res.failed()); + + ASSERT_TRUE(serv_res.state == DecodeResult::State::OK); + std::string data_received(server_plaintext_out.data(), serv_res.bytes_produced); + EXPECT_EQUAL(client_plaintext, data_received); +} + +/* + * TODO tests: + * - full duplex read/write + * - read and write of > frame size data + * - handshakes with multi frame writes + * - completed handshake with pipelined data frame + * - short ciphertext reads on decode + * - short plaintext writes on decode (.. if we even want to support this..) + * - short ciphertext write on encode + * - peer certificate validation on server + * - peer certificate validation on client + * - detection of peer shutdown session + */ + +TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/vespalib/src/tests/net/tls/transport_options/CMakeLists.txt b/vespalib/src/tests/net/tls/transport_options/CMakeLists.txt new file mode 100644 index 00000000000..ee1e2477708 --- /dev/null +++ b/vespalib/src/tests/net/tls/transport_options/CMakeLists.txt @@ -0,0 +1,10 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_net_tls_transport_options_test_app TEST + SOURCES + transport_options_reading_test.cpp + DEPENDS + vespalib +) +vespa_add_test(NAME vespalib_net_tls_transport_options_test_app + COMMAND vespalib_net_tls_transport_options_test_app) + diff --git a/vespalib/src/tests/net/tls/transport_options/dummy_ca_certs.txt b/vespalib/src/tests/net/tls/transport_options/dummy_ca_certs.txt new file mode 100644 index 00000000000..b617f6f17e4 --- /dev/null +++ b/vespalib/src/tests/net/tls/transport_options/dummy_ca_certs.txt @@ -0,0 +1 @@ +My CA certificates diff --git a/vespalib/src/tests/net/tls/transport_options/dummy_certs.txt b/vespalib/src/tests/net/tls/transport_options/dummy_certs.txt new file mode 100644 index 00000000000..088b91ff770 --- /dev/null +++ b/vespalib/src/tests/net/tls/transport_options/dummy_certs.txt @@ -0,0 +1 @@ +My certificate chain diff --git a/vespalib/src/tests/net/tls/transport_options/dummy_privkey.txt b/vespalib/src/tests/net/tls/transport_options/dummy_privkey.txt new file mode 100644 index 00000000000..f29585fe31f --- /dev/null +++ b/vespalib/src/tests/net/tls/transport_options/dummy_privkey.txt @@ -0,0 +1 @@ +My private key diff --git a/vespalib/src/tests/net/tls/transport_options/ok_config.json b/vespalib/src/tests/net/tls/transport_options/ok_config.json new file mode 100644 index 00000000000..dd2591661dc --- /dev/null +++ b/vespalib/src/tests/net/tls/transport_options/ok_config.json @@ -0,0 +1,7 @@ +{ + "files":{ + "private-key": "dummy_privkey.txt", + "ca-certificates": "dummy_ca_certs.txt", + "certificates": "dummy_certs.txt" + } +} diff --git a/vespalib/src/tests/net/tls/transport_options/transport_options_reading_test.cpp b/vespalib/src/tests/net/tls/transport_options/transport_options_reading_test.cpp new file mode 100644 index 00000000000..1ce4a4353d0 --- /dev/null +++ b/vespalib/src/tests/net/tls/transport_options/transport_options_reading_test.cpp @@ -0,0 +1,65 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include <vespa/vespalib/io/fileutil.h> +#include <vespa/vespalib/net/tls/transport_security_options.h> +#include <vespa/vespalib/net/tls/transport_security_options_reading.h> +#include <vespa/vespalib/testkit/test_kit.h> +#include <vespa/vespalib/util/exceptions.h> + +using namespace vespalib; +using namespace vespalib::net::tls; + +TEST("can load TLS credentials via config file") { + auto opts = read_options_from_json_file("ok_config.json"); + ASSERT_TRUE(opts.get() != nullptr); + // Obviously we'd need to change this to actual PEM data if config reading started + // actually verifying the _content_ of files, not just reading them. + EXPECT_EQUAL("My private key\n", opts->private_key_pem()); + EXPECT_EQUAL("My CA certificates\n", opts->ca_certs_pem()); + EXPECT_EQUAL("My certificate chain\n", opts->cert_chain_pem()); +} + +TEST("missing JSON file throws exception") { + EXPECT_EXCEPTION(read_options_from_json_file("missing_config.json"), IllegalArgumentException, + "TLS config file 'missing_config.json' could not be read"); +} + +TEST("bad JSON content throws exception") { + const char* bad_json = "hello world :D"; + EXPECT_EXCEPTION(read_options_from_json_string(bad_json), IllegalArgumentException, + "Provided TLS config file is not valid JSON"); +} + +TEST("missing 'files' field throws exception") { + const char* incomplete_json = R"({})"; + EXPECT_EXCEPTION(read_options_from_json_string(incomplete_json), IllegalArgumentException, + "TLS config root field 'files' is missing or empty"); +} + +TEST("missing 'private-key' field throws exception") { + const char* incomplete_json = R"({"files":{"certificates":"dummy_certs.txt","ca-certificates":"dummy_ca_certs.txt"}})"; + EXPECT_EXCEPTION(read_options_from_json_string(incomplete_json), IllegalArgumentException, + "TLS config field 'private-key' has not been set"); +} + +TEST("missing 'certificates' field throws exception") { + const char* incomplete_json = R"({"files":{"private-key":"dummy_privkey.txt","ca-certificates":"dummy_ca_certs.txt"}})"; + EXPECT_EXCEPTION(read_options_from_json_string(incomplete_json), IllegalArgumentException, + "TLS config field 'certificates' has not been set"); +} + +TEST("missing 'ca-certificates' field throws exception") { + const char* incomplete_json = R"({"files":{"private-key":"dummy_privkey.txt","certificates":"dummy_certs.txt"}})"; + EXPECT_EXCEPTION(read_options_from_json_string(incomplete_json), IllegalArgumentException, + "TLS config field 'ca-certificates' has not been set"); +} + +TEST("missing file referenced by field throws exception") { + const char* incomplete_json = R"({"files":{"private-key":"missing_privkey.txt", + "certificates":"dummy_certs.txt", + "ca-certificates":"dummy_ca_certs.txt"}})"; + EXPECT_EXCEPTION(read_options_from_json_string(incomplete_json), IllegalArgumentException, + "File 'missing_privkey.txt' referenced by TLS config does not exist"); +} + +TEST_MAIN() { TEST_RUN_ALL(); } + diff --git a/vespalib/src/vespa/vespalib/CMakeLists.txt b/vespalib/src/vespa/vespalib/CMakeLists.txt index 480caf8f28d..8261bb8874e 100644 --- a/vespalib/src/vespa/vespalib/CMakeLists.txt +++ b/vespalib/src/vespa/vespalib/CMakeLists.txt @@ -9,8 +9,11 @@ vespa_add_library(vespalib $<TARGET_OBJECTS:vespalib_vespalib_io> $<TARGET_OBJECTS:vespalib_vespalib_locale> $<TARGET_OBJECTS:vespalib_vespalib_net> + $<TARGET_OBJECTS:vespalib_vespalib_net_tls> + $<TARGET_OBJECTS:vespalib_vespalib_net_tls_impl> $<TARGET_OBJECTS:vespalib_vespalib_objects> $<TARGET_OBJECTS:vespalib_vespalib_stllike> + $<TARGET_OBJECTS:vespalib_vespalib_test> $<TARGET_OBJECTS:vespalib_vespalib_testkit> $<TARGET_OBJECTS:vespalib_vespalib_text> $<TARGET_OBJECTS:vespalib_vespalib_time> @@ -20,6 +23,7 @@ vespa_add_library(vespalib $<TARGET_OBJECTS:vespalib_vespalib_xxhash> INSTALL lib64 DEPENDS - vespalib_vespalib_test gcc ) + +vespa_add_target_package_dependency(vespalib OpenSSL) diff --git a/vespalib/src/vespa/vespalib/net/crypto_engine.cpp b/vespalib/src/vespa/vespalib/net/crypto_engine.cpp index 8832b4b1cfe..38a91456cba 100644 --- a/vespalib/src/vespa/vespalib/net/crypto_engine.cpp +++ b/vespalib/src/vespa/vespalib/net/crypto_engine.cpp @@ -5,6 +5,10 @@ #include <chrono> #include <thread> #include <vespa/vespalib/xxhash/xxhash.h> +#include <vespa/vespalib/stllike/string.h> +#include <vespa/vespalib/net/tls/transport_security_options.h> +#include <vespa/vespalib/net/tls/transport_security_options_reading.h> +#include <vespa/vespalib/net/tls/tls_crypto_engine.h> #include <assert.h> namespace vespalib { @@ -156,9 +160,13 @@ public: }; CryptoEngine::SP create_default_crypto_engine() { - // TODO: check VESPA_TLS_CONFIG_FILE here - // return std::make_shared<XorCryptoEngine>(); - return std::make_shared<NullCryptoEngine>(); + const char *env = getenv("VESPA_TLS_CONFIG_FILE"); + vespalib::string cfg_file = env ? env : ""; + if (cfg_file.empty()) { + return std::make_shared<NullCryptoEngine>(); + } + auto tls_opts = net::tls::read_options_from_json_file(cfg_file); + return std::make_shared<TlsCryptoEngine>(*tls_opts); } } // namespace vespalib::<unnamed> diff --git a/vespalib/src/vespa/vespalib/net/crypto_socket.h b/vespalib/src/vespa/vespalib/net/crypto_socket.h index 7fe7871960f..f78f7fc0ce7 100644 --- a/vespalib/src/vespa/vespalib/net/crypto_socket.h +++ b/vespalib/src/vespa/vespalib/net/crypto_socket.h @@ -74,13 +74,16 @@ struct CryptoSocket { virtual ssize_t write(const char *buf, size_t len) = 0; /** - * Try to flush data in the write pipeline that is not depenedent + * Try to flush data in the write pipeline that is not dependent * on data not yet written by the application into the underlying * socket. This is to enable the application to identify pending * work that may not be completed until the underlying socket is * ready for writing more data. The semantics are the same as with * a normal socket write (errno, etc.) with the exception that 0 - * will be returned when there is no more data to flush. + * will be returned when there is no more data to flush and any + * positive number indicates that we were able to flush something + * (it does not need to reflect the actual number of bytes written + * to the underlying socket). **/ virtual ssize_t flush() = 0; diff --git a/vespalib/src/vespa/vespalib/net/tls/CMakeLists.txt b/vespalib/src/vespa/vespalib/net/tls/CMakeLists.txt new file mode 100644 index 00000000000..2d34a3e1c80 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/CMakeLists.txt @@ -0,0 +1,15 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_library(vespalib_vespalib_net_tls OBJECT + SOURCES + crypto_codec.cpp + crypto_codec_adapter.cpp + crypto_exception.cpp + tls_context.cpp + tls_crypto_engine.cpp + transport_security_options.cpp + transport_security_options_reading.cpp + DEPENDS +) +find_package(OpenSSL) +target_include_directories(vespalib_vespalib_net_tls PUBLIC ${OPENSSL_INCLUDE_DIR}) + diff --git a/vespalib/src/vespa/vespalib/net/tls/crypto_codec.cpp b/vespalib/src/vespa/vespalib/net/tls/crypto_codec.cpp new file mode 100644 index 00000000000..b36913d20e3 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/crypto_codec.cpp @@ -0,0 +1,15 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "crypto_codec.h" +#include <vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.h> +#include <vespa/vespalib/net/tls/impl/openssl_tls_context_impl.h> +#include <cassert> + +namespace vespalib::net::tls { + +std::unique_ptr<CryptoCodec> CryptoCodec::create_default_codec(TlsContext& ctx, Mode mode) { + auto* ssl_ctx = dynamic_cast<impl::OpenSslTlsContextImpl*>(&ctx); + assert(ssl_ctx != nullptr); + return std::make_unique<impl::OpenSslCryptoCodecImpl>(*ssl_ctx->native_context(), mode); +} + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/crypto_codec.h b/vespalib/src/vespa/vespalib/net/tls/crypto_codec.h new file mode 100644 index 00000000000..6e690c809a5 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/crypto_codec.h @@ -0,0 +1,124 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <memory> + +namespace vespalib::net::tls { + +struct HandshakeResult { + // Handshake bytes consumed from peer. + size_t bytes_consumed = 0; + // Handshake bytes produced that must be sent to the peer. + size_t bytes_produced = 0; + enum class State { + Failed, + Done, + NeedsMorePeerData + }; + State state = State::Failed; + + bool failed() const noexcept { return (state == State::Failed); } + bool done() const noexcept { return (state == State::Done); } +}; + +struct EncodeResult { + // Plaintext bytes consumed + size_t bytes_consumed = 0; + // Ciphertext bytes produced that must be sent to the peer + size_t bytes_produced = 0; + bool failed = true; +}; + +struct DecodeResult { + // Ciphertext bytes consumed from peer + size_t bytes_consumed = 0; + // Plaintext bytes produced. + size_t bytes_produced = 0; + enum class State { + Failed, + OK, + NeedsMorePeerData + // TODO add Closed/Shutdown as own state? + }; + State state = State::Failed; + + bool failed() const noexcept { return (state == State::Failed); } +}; + +class TlsContext; + +// TODO move to different namespace, not dependent on TLS? + +/* + * A CryptoCodec provides a fully transport-independent way of negotiating + * a secure, authenticated session towards another peer. The codec requires + * the caller to handle any and all actual data transfer + */ +class CryptoCodec { +public: + enum class Mode { + Client, Server + }; + + virtual ~CryptoCodec() = default; + + /* + * Minimum buffer size required to represent one wire format frame + * of encrypted (ciphertext) data, including frame overhead. + */ + virtual size_t min_encode_buffer_size() const noexcept = 0; + /* + * Minimum buffer size required to represent the decoded (plaintext) + * output of a single frame of encrypted data. + */ + virtual size_t min_decode_buffer_size() const noexcept = 0; + + /* + * Precondition: to_peer_buf_size >= min_encode_buffer_size() + * Postcondition: if result.done(), the handshake process has completed + * and data may be passed through encode()/decode(). + */ + virtual HandshakeResult handshake(const char* from_peer, size_t from_peer_buf_size, + char* to_peer, size_t to_peer_buf_size) noexcept = 0; + + /* + * Encodes a single ciphertext frame into `ciphertext`. If plaintext_size + * is greater than can fit into a frame, the returned result's consumed_bytes + * field will be < plaintext_size. The number of actual ciphertext bytes produced + * is available in the returned result's produced_bytes field. + * + * Precondition: handshake must be completed + * Precondition: ciphertext_size >= min_encode_buffer_size(), i.e. it must be + * possible to encode at least 1 frame. + * Postcondition: if plaintext_size > 0 and result.failed == false, a single + * frame of ciphertext has been written into the to_peer buffer. + * Size of written frame is given by result.bytes_produced. This + * includes all protocol-specific frame overhead. + */ + virtual EncodeResult encode(const char* plaintext, size_t plaintext_size, + char* ciphertext, size_t ciphertext_size) noexcept = 0; + /* + * Attempt to decode ciphertext sent by the peer into plaintext. Since + * ciphertext is sent in frames, it's possible that invoking decode() + * may produce a CodecResult with a state of `NeedsMorePeerData` if a + * complete frame is not present in `ciphertext`. In this case, decode() + * must be called again once more data is available. + * + * Precondition: handshake must be completed + * Precondition: plaintext_size >= min_decode_buffer_size() + * Postcondition: if result.state == DecodeResult::State::OK, at least 1 + * complete frame has been written to the `plaintext` buffer + */ + virtual DecodeResult decode(const char* ciphertext, size_t ciphertext_size, + char* plaintext, size_t plaintext_size) noexcept = 0; + + /* + * Creates an implementation defined CryptoCodec that provides at least TLSv1.2 + * compliant handshaking and full duplex data transfer. + * + * Throws CryptoException if resources cannot be allocated for the codec. + */ + static std::unique_ptr<CryptoCodec> create_default_codec(TlsContext& ctx, Mode mode); +}; + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/crypto_codec_adapter.cpp b/vespalib/src/vespa/vespalib/net/tls/crypto_codec_adapter.cpp new file mode 100644 index 00000000000..435f16cc340 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/crypto_codec_adapter.cpp @@ -0,0 +1,146 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "crypto_codec_adapter.h" +#include <assert.h> + +namespace vespalib::net::tls { + +CryptoSocket::HandshakeResult +CryptoCodecAdapter::hs_try_flush() +{ + auto flush_res = flush_all(); + if (flush_res == 0) { + return HandshakeResult::DONE; + } else if (is_blocked(flush_res, errno)) { + return HandshakeResult::NEED_WRITE; + } else { + return HandshakeResult::FAIL; + } +} + +CryptoSocket::HandshakeResult +CryptoCodecAdapter::hs_try_fill() +{ + auto fill_res = fill_input(); + if (fill_res > 0) { + return HandshakeResult::DONE; + } else if (is_blocked(fill_res, errno)) { + return HandshakeResult::NEED_READ; + } else { // eof included here + return HandshakeResult::FAIL; + } +} + +ssize_t +CryptoCodecAdapter::fill_input() +{ + if (_input.get().size < _codec->min_encode_buffer_size()) { + auto dst = _input.reserve(_codec->min_encode_buffer_size()); + ssize_t res = _socket.read(dst.data, dst.size); + if (res > 0) { + _input.commit(res); + } else { + return res; // eof/error + } + } + return 1; // progress +} + +ssize_t +CryptoCodecAdapter::flush_all() +{ + ssize_t res = flush(); + while (res > 0) { + res = flush(); + } + return res; +} + +CryptoSocket::HandshakeResult +CryptoCodecAdapter::handshake() +{ + for (;;) { + auto in = _input.obtain(); + auto out = _output.reserve(_codec->min_encode_buffer_size()); + auto hs_res = _codec->handshake(in.data, in.size, out.data, out.size); + _input.evict(hs_res.bytes_consumed); + _output.commit(hs_res.bytes_produced); + switch (hs_res.state) { + case ::vespalib::net::tls::HandshakeResult::State::Failed: return HandshakeResult::FAIL; + case ::vespalib::net::tls::HandshakeResult::State::Done: return hs_try_flush(); + case ::vespalib::net::tls::HandshakeResult::State::NeedsMorePeerData: + auto flush_res = hs_try_flush(); + if (flush_res != HandshakeResult::DONE) { + return flush_res; + } + auto fill_res = hs_try_fill(); + if (fill_res != HandshakeResult::DONE) { + return fill_res; + } + } + } + return HandshakeResult::DONE; +} + +ssize_t +CryptoCodecAdapter::read(char *buf, size_t len) +{ + auto fill_res = fill_input(); + if (fill_res <= 0) { + return fill_res; + } + ssize_t res = drain(buf, len); + if (res != 0) { + return res; + } + errno = EWOULDBLOCK; + return -1; +} + +ssize_t +CryptoCodecAdapter::drain(char *buf, size_t len) +{ + auto src = _input.obtain(); + auto res = _codec->decode(src.data, src.size, buf, len); + if (res.failed()) { + errno = EIO; + return -1; + } + _input.evict(res.bytes_consumed); + return res.bytes_produced; +} + +ssize_t +CryptoCodecAdapter::write(const char *buf, size_t len) +{ + if (flush_all() < 0) { + return -1; + } + auto dst = _output.reserve(_codec->min_encode_buffer_size()); + auto res = _codec->encode(buf, len, dst.data, dst.size); + if (res.failed) { + errno = EIO; + return -1; + } + _output.commit(res.bytes_produced); + return res.bytes_consumed; +} + +ssize_t +CryptoCodecAdapter::flush() +{ + auto pending = _output.obtain(); + if (pending.size > 0) { + ssize_t res = _socket.write(pending.data, pending.size); + if (res > 0) { + _output.evict(res); + return 1; // progress + } else { + assert(res < 0); + return -1; // error + } + } + return 0; // done +} + +} // namespace vespalib::net::tls diff --git a/vespalib/src/vespa/vespalib/net/tls/crypto_codec_adapter.h b/vespalib/src/vespa/vespalib/net/tls/crypto_codec_adapter.h new file mode 100644 index 00000000000..6a624ca44f7 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/crypto_codec_adapter.h @@ -0,0 +1,46 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/vespalib/net/crypto_socket.h> +#include <vespa/vespalib/net/socket_handle.h> +#include <vespa/vespalib/data/simple_buffer.h> +#include "crypto_codec.h" + +namespace vespalib::net::tls { + +/** + * Component adapting an underlying CryptoCodec to the CryptoSocket + * interface by performing buffer and socket management. + * + * NOTE: initial implementation is for functionality/proof-of-concept + * purposes, not performance. + **/ +class CryptoCodecAdapter : public CryptoSocket +{ +private: + SimpleBuffer _input; + SimpleBuffer _output; + SocketHandle _socket; + std::unique_ptr<CryptoCodec> _codec; + + bool is_blocked(ssize_t res, int error) const { + return ((res < 0) && ((error == EWOULDBLOCK) || (error == EAGAIN))); + } + HandshakeResult hs_try_flush(); + HandshakeResult hs_try_fill(); + ssize_t fill_input(); // -1/0/1 -> error/eof/ok + ssize_t flush_all(); // -1/0 -> error/ok +public: + CryptoCodecAdapter(SocketHandle socket, std::unique_ptr<CryptoCodec> codec) + : _socket(std::move(socket)), _codec(std::move(codec)) {} + int get_fd() const override { return _socket.get(); } + HandshakeResult handshake() override; + size_t min_read_buffer_size() const override { return _codec->min_decode_buffer_size(); } + ssize_t read(char *buf, size_t len) override; + ssize_t drain(char *, size_t) override; + ssize_t write(const char *buf, size_t len) override; + ssize_t flush() override; +}; + +} // namespace vespalib::net::tls diff --git a/vespalib/src/vespa/vespalib/net/tls/crypto_exception.cpp b/vespalib/src/vespa/vespalib/net/tls/crypto_exception.cpp new file mode 100644 index 00000000000..41bb2060c04 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/crypto_exception.cpp @@ -0,0 +1,10 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "crypto_exception.h" + +namespace vespalib::net::tls { + +VESPA_IMPLEMENT_EXCEPTION(CryptoException, Exception); + +} + diff --git a/vespalib/src/vespa/vespalib/net/tls/crypto_exception.h b/vespalib/src/vespa/vespalib/net/tls/crypto_exception.h new file mode 100644 index 00000000000..696a158e058 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/crypto_exception.h @@ -0,0 +1,10 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <vespa/vespalib/util/exception.h> + +namespace vespalib::net::tls { + +VESPA_DEFINE_EXCEPTION(CryptoException, Exception); + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/impl/CMakeLists.txt b/vespalib/src/vespa/vespalib/net/tls/impl/CMakeLists.txt new file mode 100644 index 00000000000..a5a8e8d3eb9 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/impl/CMakeLists.txt @@ -0,0 +1,10 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_library(vespalib_vespalib_net_tls_impl OBJECT + SOURCES + openssl_tls_context_impl.cpp + openssl_crypto_codec_impl.cpp + DEPENDS +) +find_package(OpenSSL) +target_include_directories(vespalib_vespalib_net_tls_impl PUBLIC ${OPENSSL_INCLUDE_DIR}) + diff --git a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.cpp b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.cpp new file mode 100644 index 00000000000..a563a43baac --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.cpp @@ -0,0 +1,383 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "openssl_crypto_codec_impl.h" +#include "openssl_tls_context_impl.h" +#include <vespa/vespalib/net/tls/crypto_codec.h> +#include <vespa/vespalib/net/tls/crypto_exception.h> +#include <mutex> +#include <vector> +#include <memory> +#include <stdexcept> +#include <openssl/ssl.h> +#include <openssl/crypto.h> +#include <openssl/err.h> +#include <openssl/pem.h> + +#include <vespa/log/log.h> +LOG_SETUP(".vespalib.net.tls.openssl_crypto_codec_impl"); + +#if (OPENSSL_VERSION_NUMBER < 0x10000000L) +// < 1.0 requires explicit thread ID callback support. +# error "Provided OpenSSL version is too darn old, need at least 1.0" +#endif + +/* + * Beware all ye who dare enter, for this is OpenSSL integration territory. + * Dragons are known to roam the skies. Strange whispers are heard at night + * in the mist-covered lands where the forest meets the lake. Rumors of a + * tome that contains best practices and excellent documentation are heard + * at the local inn, but no one seems to know where it exists, or even if + * it ever existed. Be it best that people carry on with their lives and + * pretend to not know of the beasts that lurk beyond where the torch's + * light fades and turns to all-enveloping darkness. + */ + +namespace vespalib::net::tls::impl { + +namespace { + +bool verify_buf(const char *buf, size_t len) { + return ((len < INT32_MAX) && ((len == 0) || (buf != nullptr))); +} + +const char* ssl_error_to_str(int ssl_error) noexcept { + // From https://www.openssl.org/docs/manmaster/man3/SSL_get_error.html + // Our code paths shouldn't trigger most of these, but included for completeness + switch (ssl_error) { + case SSL_ERROR_NONE: + return "SSL_ERROR_NONE"; + case SSL_ERROR_ZERO_RETURN: + return "SSL_ERROR_ZERO_RETURN"; + case SSL_ERROR_WANT_READ: + return "SSL_ERROR_WANT_READ"; + case SSL_ERROR_WANT_WRITE: + return "SSL_ERROR_WANT_WRITE"; + case SSL_ERROR_WANT_CONNECT: + return "SSL_ERROR_WANT_CONNECT"; + case SSL_ERROR_WANT_ACCEPT: + return "SSL_ERROR_WANT_ACCEPT"; + case SSL_ERROR_WANT_X509_LOOKUP: + return "SSL_ERROR_WANT_X509_LOOKUP"; +#if (OPENSSL_VERSION_NUMBER >= 0x10100000L) + case SSL_ERROR_WANT_ASYNC: + return "SSL_ERROR_WANT_ASYNC"; + case SSL_ERROR_WANT_ASYNC_JOB: + return "SSL_ERROR_WANT_ASYNC_JOB"; +#endif +#if (OPENSSL_VERSION_NUMBER >= 0x10101000L) + case SSL_ERROR_WANT_CLIENT_HELLO_CB: + return "SSL_ERROR_WANT_CLIENT_HELLO_CB"; +#endif + case SSL_ERROR_SYSCALL: + return "SSL_ERROR_SYSCALL"; + case SSL_ERROR_SSL: + return "SSL_ERROR_SSL"; + default: + return "Unknown SSL error code"; + } +} + +HandshakeResult handshake_consumed_bytes_and_needs_more_peer_data(size_t consumed) noexcept { + return {consumed, 0, HandshakeResult::State::NeedsMorePeerData}; +} + +HandshakeResult handshake_produced_bytes_and_needs_more_peer_data(size_t produced) noexcept { + return {0, produced, HandshakeResult::State::NeedsMorePeerData}; +} + +HandshakeResult handshake_consumed_bytes_and_is_complete(size_t consumed) noexcept { + return {consumed, 0, HandshakeResult::State::Done}; +} + +HandshakeResult handshaked_bytes(size_t consumed, size_t produced, HandshakeResult::State state) noexcept { + return {consumed, produced, state}; +} + +HandshakeResult handshake_completed() noexcept { + return {0, 0, HandshakeResult::State::Done}; +} + +HandshakeResult handshake_failed() noexcept { + return {0, 0, HandshakeResult::State::Failed}; +} + +EncodeResult encode_failed() noexcept { + return {0, 0, true}; +} + +EncodeResult encoded_bytes(size_t consumed, size_t produced) noexcept { + return {consumed, produced, false}; +} + +DecodeResult decode_failed() noexcept { + return {0, 0, DecodeResult::State::Failed}; +} + +DecodeResult decoded_frames_with_plaintext_bytes(size_t produced_bytes) noexcept { + return {0, produced_bytes, DecodeResult::State::OK}; +} + +DecodeResult decode_needs_more_peer_data() noexcept { + return {0, 0, DecodeResult::State::NeedsMorePeerData}; +} + +DecodeResult decoded_bytes(size_t consumed, size_t produced, DecodeResult::State state) noexcept { + return {consumed, produced, state}; +} + +BioPtr new_tls_frame_memory_bio() { + BioPtr bio(::BIO_new(BIO_s_mem())); + if (!bio) { + throw CryptoException("IO_new(BIO_s_mem()) failed; out of memory?"); + } + BIO_set_write_buf_size(bio.get(), 0); // 0 ==> default max frame size + return bio; +} + +} // anon ns + +OpenSslCryptoCodecImpl::OpenSslCryptoCodecImpl(::SSL_CTX& ctx, Mode mode) + : _ssl(::SSL_new(&ctx)), + _mode(mode) +{ + if (!_ssl) { + throw CryptoException("Failed to create new SSL from SSL_CTX"); + } + /* + * We use two separate memory BIOs rather than a BIO pair for writing and + * reading ciphertext, respectively. This is because it _seems_ quite + * a bit more straight forward to implement a full duplex API with two + * separate BIOs, but there is little available documentation as to the + * 'hows' and 'whys' around this. + * There are claims from core OpenSSL devs[0] that BIO pairs are more efficient, + * so we may reconsider the current approach (or just use the "OpenSSL controls + * the file descriptor" yolo approach for simplicity, assuming they do optimal + * stuff internally). + * + * Our BIOs are used as follows: + * + * Handshakes may use both BIOs opaquely: + * + * handshake() : SSL_do_handshake() --(_output_bio ciphertext)--> BIO_read --> [peer] + * : SSL_do_handshake() <--(_input_bio ciphertext)-- BIO_write <-- [peer] + * + * Once handshaking is complete, the input BIO is only used for decodes and the output + * BIO is only used for encodes. We explicitly disallow TLS renegotiation, both for + * the sake of simplicity and for added security (renegotiation is a bit of a rat's nest). + * + * encode() : SSL_write(plaintext) --(_output_bio ciphertext)--> BIO_read --> [peer] + * decode() : SSL_read(plaintext) <--(_input_bio ciphertext)-- BIO_write <-- [peer] + * + * To avoid blowing the sizes of BIOs out of the water, we do our best to encode and decode + * on a per-TLS frame granularity (16K) maximum. + */ + BioPtr tmp_input_bio = new_tls_frame_memory_bio(); + BioPtr tmp_output_bio = new_tls_frame_memory_bio(); + // Connect BIOs used internally by OpenSSL. This transfers ownership. No return value to check. + // TODO replace with explicit SSL_set0_rbio/SSL_set0_wbio on OpenSSL >= v1.1 + ::SSL_set_bio(_ssl.get(), tmp_input_bio.get(), tmp_output_bio.get()); + _input_bio = tmp_input_bio.release(); + _output_bio = tmp_output_bio.release(); + if (_mode == Mode::Client) { + ::SSL_set_connect_state(_ssl.get()); + } else { + ::SSL_set_accept_state(_ssl.get()); + } +} + +// TODO remove spammy logging once code is stable + +// Produces bytes previously written to _output_bio by SSL_do_handshake or SSL_write +int OpenSslCryptoCodecImpl::drain_outgoing_network_bytes_if_any( + char *to_peer, size_t to_peer_buf_size) noexcept { + int out_pending = BIO_pending(_output_bio); + if (out_pending > 0) { + int copied = ::BIO_read(_output_bio, to_peer, static_cast<int>(to_peer_buf_size)); + // TODO BIO_should_retry here? Semantics are unclear, especially for memory BIOs. + LOG(spam, "BIO_read copied out %d bytes of ciphertext from _output_bio", copied); + if (copied < 0) { + LOG(error, "Memory BIO_read() failed with BIO_pending() > 0"); + } + return copied; + } + return out_pending; +} + +HandshakeResult OpenSslCryptoCodecImpl::handshake(const char* from_peer, size_t from_peer_buf_size, + char* to_peer, size_t to_peer_buf_size) noexcept { + LOG_ASSERT(verify_buf(from_peer, from_peer_buf_size) && verify_buf(to_peer, to_peer_buf_size)); + + if (SSL_is_init_finished(_ssl.get())) { + return handshake_completed(); + } + // Still ciphertext data left? If so, get rid of it before we start a new operation + // that wants to fill the output BIO. + int produced = drain_outgoing_network_bytes_if_any(to_peer, to_peer_buf_size); + if (produced > 0) { + // Handshake isn't complete yet and we've got stuff to send. Need to continue handshake + // once more data is available from the peer. + return handshake_produced_bytes_and_needs_more_peer_data(static_cast<size_t>(produced)); + } else if (produced < 0) { + return handshake_failed(); + } + const auto consume_res = do_handshake_and_consume_peer_input_bytes(from_peer, from_peer_buf_size); + LOG_ASSERT(consume_res.bytes_produced == 0); + if (consume_res.failed()) { + return consume_res; + } + // SSL_do_handshake() might have produced more data to send. Note: handshake may + // be complete at this point. + produced = drain_outgoing_network_bytes_if_any(to_peer, to_peer_buf_size); + if (produced < 0) { + return handshake_failed(); + } + return handshaked_bytes(consume_res.bytes_consumed, static_cast<size_t>(produced), consume_res.state); +} + +HandshakeResult OpenSslCryptoCodecImpl::do_handshake_and_consume_peer_input_bytes( + const char *from_peer, size_t from_peer_buf_size) noexcept { + // Feed the SSL session input in frame-sized chunks between each call to SSL_do_handshake(). + // This is primarily to ensure we don't shove unbounded amounts of data into the BIO + // in the case that someone naughty is sending us tons of garbage over the socket. + size_t consumed_total = 0; + while (true) { + // Assumption: SSL_do_handshake will place all required outgoing handshake + // data in the output memory BIO without requiring WANT_WRITE. Freestanding + // memory BIOs are _supposedly_ auto-resizing, so this should work transparently. + // At the very least, if this is not the case we'll auto-fail the connection + // and quickly find out..! + // TODO test multi-frame sized handshake + // TODO should we invoke ::ERR_clear_error() prior? + int ssl_result = ::SSL_do_handshake(_ssl.get()); + ssl_result = ::SSL_get_error(_ssl.get(), ssl_result); + + if (ssl_result == SSL_ERROR_WANT_READ) { + LOG(spam, "SSL_do_handshake() returned SSL_ERROR_WANT_READ"); + if (from_peer_buf_size - consumed_total > 0) { + int consumed = ::BIO_write(_input_bio, from_peer + consumed_total, + static_cast<int>(std::min(MaximumTlsFrameSize, from_peer_buf_size - consumed_total))); + LOG(spam, "BIO_write copied in %d bytes of ciphertext to _input_bio", consumed); + if (consumed < 0) { + LOG(error, "Memory BIO_write() returned %d", consumed); // TODO BIO_need_retry? + return handshake_failed(); + } + consumed_total += consumed; // TODO protect against consumed == 0? + continue; + } else { + return handshake_consumed_bytes_and_needs_more_peer_data(consumed_total); + } + } else if (ssl_result == SSL_ERROR_NONE) { + // At this point SSL_do_handshake has stated it does not need any more peer data, i.e. + // the handshake is complete. + if (!SSL_is_init_finished(_ssl.get())) { + LOG(error, "SSL handshake is not completed even though no more peer data is requested"); + return handshake_failed(); + } + return handshake_consumed_bytes_and_is_complete(consumed_total); + } else { + LOG(error, "SSL_do_handshake() returned unexpected error: %s", ssl_error_to_str(ssl_result)); + return handshake_failed(); + } + }; +} + +EncodeResult OpenSslCryptoCodecImpl::encode(const char* plaintext, size_t plaintext_size, + char* ciphertext, size_t ciphertext_size) noexcept { + LOG_ASSERT(verify_buf(plaintext, plaintext_size) && verify_buf(ciphertext, ciphertext_size)); + + if (!SSL_is_init_finished(_ssl.get())) { + LOG(error, "OpenSslCryptoCodecImpl::encode() called before handshake completed"); + return encode_failed(); + } + size_t bytes_consumed = 0; + if (plaintext_size != 0) { + int to_consume = static_cast<int>(std::min(plaintext_size, MaximumFramePlaintextSize)); + // SSL_write encodes plaintext to ciphertext and writes to _output_bio + int consumed = ::SSL_write(_ssl.get(), plaintext, to_consume); + LOG(spam, "After SSL_write() -> %d, _input_bio pending=%d, _output_bio pending=%d", + consumed, BIO_pending(_input_bio), BIO_pending(_output_bio)); + if (consumed < 0) { + int ssl_error = ::SSL_get_error(_ssl.get(), consumed); + LOG(error, "SSL_write() failed to write frame, got error %s", ssl_error_to_str(ssl_error)); + // TODO explicitly detect and log TLS renegotiation error (SSL_ERROR_WANT_READ)? + return encode_failed(); + } else if (consumed != to_consume) { + LOG(error, "SSL_write() returned OK but did not consume all requested plaintext"); + return encode_failed(); + } + bytes_consumed = static_cast<size_t>(consumed); + } + + int produced = drain_outgoing_network_bytes_if_any(ciphertext, ciphertext_size); + if (produced < 0) { + return encode_failed(); + } + if (BIO_pending(_output_bio) != 0) { + LOG(error, "Residual data left in output BIO on encode(); provided buffer is too small"); + return encode_failed(); + } + return encoded_bytes(bytes_consumed, static_cast<size_t>(produced)); +} +DecodeResult OpenSslCryptoCodecImpl::decode(const char* ciphertext, size_t ciphertext_size, + char* plaintext, size_t plaintext_size) noexcept { + LOG_ASSERT(verify_buf(ciphertext, ciphertext_size) && verify_buf(plaintext, plaintext_size)); + + if (!SSL_is_init_finished(_ssl.get())) { + LOG(error, "OpenSslCryptoCodecImpl::decode() called before handshake completed"); + return decode_failed(); + } + auto produce_res = drain_and_produce_plaintext_from_ssl(plaintext, static_cast<int>(plaintext_size)); + if ((produce_res.bytes_produced > 0) || produce_res.failed()) { + return produce_res; // TODO gRPC [1] handles this differently... allows fallthrough + } + int consumed = consume_peer_input_bytes(ciphertext, ciphertext_size); + if (consumed < 0) { + return decode_failed(); + } + produce_res = drain_and_produce_plaintext_from_ssl(plaintext, static_cast<int>(plaintext_size)); + return decoded_bytes(static_cast<size_t>(consumed), produce_res.bytes_produced, produce_res.state); +} + +DecodeResult OpenSslCryptoCodecImpl::drain_and_produce_plaintext_from_ssl( + char* plaintext, size_t plaintext_size) noexcept { + // SSL_read() is named a bit confusingly. We read _from_ the SSL-internal state + // via the input BIO _into_ to the receiving plaintext buffer. + // This may consume the entire, parts of, or none of the input BIO's data, + // depending on how much TLS frame data is available and its size relative + // to the receiving plaintext buffer. + int produced = ::SSL_read(_ssl.get(), plaintext, static_cast<int>(plaintext_size)); + LOG(spam, "After SSL_read() -> %d, _input_bio pending=%d, _output_bio pending=%d", + produced, BIO_pending(_input_bio), BIO_pending(_output_bio)); + if (produced > 0) { + // At least 1 frame decoded successfully. + return decoded_frames_with_plaintext_bytes(static_cast<size_t>(produced)); + } else { + int ssl_error = ::SSL_get_error(_ssl.get(), produced); + switch (ssl_error) { + case SSL_ERROR_WANT_READ: + // SSL_read() was not able to decode a full frame with the ciphertext that + // we've fed it thus far; caller must feed it some and then try again. + LOG(spam, "SSL_read() returned SSL_ERROR_WANT_READ, must get more ciphertext"); + return decode_needs_more_peer_data(); + default: + LOG(error, "SSL_read() returned unexpected error: %s", ssl_error_to_str(ssl_error)); + return decode_failed(); + } + } +} + +int OpenSslCryptoCodecImpl::consume_peer_input_bytes( + const char* ciphertext, size_t ciphertext_size) noexcept { + // TODO BIO_need_retry on failure? Can this even happen for memory BIOs? + int consumed = ::BIO_write(_input_bio, ciphertext, static_cast<int>(std::min(MaximumTlsFrameSize, ciphertext_size))); + LOG(spam, "BIO_write copied in %d bytes of ciphertext to _input_bio", consumed); + if (consumed < 0) { + LOG(error, "Memory BIO_write() returned %d", consumed); + } + return consumed; +} + +} + +// External references: +// [0] http://openssl.6102.n7.nabble.com/nonblocking-implementation-question-tp1728p1732.html +// [1] https://github.com/grpc/grpc/blob/master/src/core/tsi/ssl_transport_security.cc diff --git a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.h b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.h new file mode 100644 index 00000000000..44ca8859596 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.h @@ -0,0 +1,76 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "openssl_typedefs.h" +#include <vespa/vespalib/net/tls/transport_security_options.h> +#include <vespa/vespalib/net/tls/crypto_codec.h> +#include <memory> + +namespace vespalib::net::tls { class TlsContext; } + +namespace vespalib::net::tls::impl { + +/* + * Frame-level OpenSSL-backed TLSv1.2 crypto codec implementation. + * + * Currently has sub-optimal buffer management, and is mostly intended + * as a starting point. + * + * NOT thread safe per instance, but independent instances may be + * used by different threads safely. + */ +class OpenSslCryptoCodecImpl : public CryptoCodec { + SslPtr _ssl; + ::BIO* _input_bio; // Owned by _ssl + ::BIO* _output_bio; // Owned by _ssl + Mode _mode; +public: + OpenSslCryptoCodecImpl(::SSL_CTX& ctx, Mode mode); + + /* + * From RFC 8449 (Record Size Limit Extension for TLS), section 1: + * "TLS versions 1.2 [RFC5246] and earlier permit senders to + * generate records 16384 octets in size, plus any expansion + * from compression and protection up to 2048 octets (though + * typically this expansion is only 16 octets). TLS 1.3 reduces + * the allowance for expansion to 256 octets." + * + * We're on TLSv1.2, so make room for the worst case. + */ + static constexpr size_t MaximumTlsFrameSize = 16384 + 2048; + static constexpr size_t MaximumFramePlaintextSize = 16384; + + size_t min_encode_buffer_size() const noexcept override { + return MaximumTlsFrameSize; + } + size_t min_decode_buffer_size() const noexcept override { + return MaximumFramePlaintextSize; + } + + HandshakeResult handshake(const char* from_peer, size_t from_peer_buf_size, + char* to_peer, size_t to_peer_buf_size) noexcept override; + + EncodeResult encode(const char* plaintext, size_t plaintext_size, + char* ciphertext, size_t ciphertext_size) noexcept override; + DecodeResult decode(const char* ciphertext, size_t ciphertext_size, + char* plaintext, size_t plaintext_size) noexcept override; +private: + /* + * Returns + * n > 0 if n bytes written to `to_peer`. Always <= to_peer_buf_size + * n == 0 if no bytes pending in output BIO + * n < 0 on error + */ + int drain_outgoing_network_bytes_if_any(char *to_peer, size_t to_peer_buf_size) noexcept; + /* + * Returns + * n > 0 if n bytes written to `ciphertext`. Always <= ciphertext_size + * n == 0 if no bytes pending in input BIO + * n < 0 on error + */ + int consume_peer_input_bytes(const char* ciphertext, size_t ciphertext_size) noexcept; + HandshakeResult do_handshake_and_consume_peer_input_bytes(const char *from_peer, size_t from_peer_buf_size) noexcept; + DecodeResult drain_and_produce_plaintext_from_ssl(char* plaintext, size_t plaintext_size) noexcept; +}; + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.cpp b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.cpp new file mode 100644 index 00000000000..27250dd43fc --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.cpp @@ -0,0 +1,269 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "openssl_typedefs.h" +#include "openssl_tls_context_impl.h" +#include <vespa/vespalib/net/tls/crypto_exception.h> +#include <vespa/vespalib/net/tls/transport_security_options.h> +#include <mutex> +#include <vector> +#include <memory> +#include <stdexcept> +#include <openssl/ssl.h> +#include <openssl/crypto.h> +#include <openssl/err.h> +#include <openssl/pem.h> + +#include <vespa/log/log.h> +LOG_SETUP(".vespalib.net.tls.openssl_tls_context_impl"); + +#if (OPENSSL_VERSION_NUMBER < 0x10000000L) +// < 1.0 requires explicit thread ID callback support. +# error "Provided OpenSSL version is too darn old, need at least 1.0" +#endif + +namespace vespalib::net::tls::impl { + +namespace { + +#if (OPENSSL_VERSION_NUMBER < 0x10100000L) + +std::vector<std::unique_ptr<std::mutex>> _g_mutexes; + +// Some works on OpenSSL legacy locking: OpenSSL does not implement locking +// itself internally, deferring to user code callbacks that Do The Needful(tm). +// The `n` parameter refers to the nth mutex, which is always < CRYPTO_num_locks(). +void openssl_locking_cb(int mode, int n, [[maybe_unused]] const char *file, [[maybe_unused]] int line) { + if (mode & CRYPTO_LOCK) { + _g_mutexes[n]->lock(); + } else { + _g_mutexes[n]->unlock(); + } +} + +#endif + +struct OpenSslLibraryResources { + OpenSslLibraryResources(); + ~OpenSslLibraryResources(); +}; + +OpenSslLibraryResources::OpenSslLibraryResources() { + // Other implementations (Asio, gRPC) disagree on whether main library init + // itself should take place on >= v1.1. We always do it to be on the safe side..! + ::SSL_library_init(); + ::SSL_load_error_strings(); + ::OpenSSL_add_all_algorithms(); + // Luckily, the mutex callback madness is not present on >= v1.1 +#if (OPENSSL_VERSION_NUMBER < 0x10100000L) + // Since the init path should happen only once globally, but multiple libraries + // may use OpenSSL, make sure we don't step on any toes if locking callbacks are + // already set up. + if (!::CRYPTO_get_locking_callback()) { + const int num_locks = ::CRYPTO_num_locks(); + LOG_ASSERT(num_locks > 0); + _g_mutexes.reserve(num_locks); + for (int i = 0; i < num_locks; ++i) { + _g_mutexes.emplace_back(std::make_unique<std::mutex>()); + } + ::CRYPTO_set_locking_callback(openssl_locking_cb); + } +#endif +} + +OpenSslLibraryResources::~OpenSslLibraryResources() { +#if (OPENSSL_VERSION_NUMBER < 0x10100000L) + if (::CRYPTO_get_locking_callback() == openssl_locking_cb) { + ::CRYPTO_set_locking_callback(nullptr); + } +#endif + ERR_free_strings(); + EVP_cleanup(); + CRYPTO_cleanup_all_ex_data(); +} + +// TODO make global init instead..? +void ensure_openssl_initialized_once() { + static OpenSslLibraryResources openssl_resources; + (void) openssl_resources; +} + +BioPtr bio_from_string(vespalib::stringref str) { + LOG_ASSERT(str.size() <= INT_MAX); +#if (OPENSSL_VERSION_NUMBER >= 0x10002000L) + BioPtr bio(::BIO_new_mem_buf(str.data(), static_cast<int>(str.size()))); +#else + BioPtr bio(::BIO_new_mem_buf(const_cast<char*>(str.data()), static_cast<int>(str.size()))); +#endif + if (!bio) { + throw CryptoException("BIO_new_mem_buf"); + } + return bio; +} + +// Several OpenSSL functions take a magical user passphrase argument with +// potentially horrible default behavior for password protected input. +// +// From OpenSSL docs (https://www.openssl.org/docs/man1.1.0/crypto/PEM_read_bio_PrivateKey.html): +// +// "If the cb parameters is set to NULL and the u parameter is not NULL +// then the u parameter is interpreted as a null terminated string to use +// as the passphrase. If both cb and u are NULL then the default callback +// routine is used which will typically prompt for the passphrase on the +// current terminal with echoing turned off." +// +// Neat! +// +// Bonus points for being non-const as well. +constexpr inline void *empty_passphrase() { + return const_cast<void *>(static_cast<const void *>("")); +} + +// Attempt to read a PEM encoded (trusted) certificate from the given BIO. +// BIO might contain further certificates if function returns non-nullptr. +// Returns nullptr if no certificate could be loaded. This is usually an error, +// as this should be the first certificate in the chain. +X509Ptr read_trusted_x509_from_bio(::BIO& bio) { + // "_AUX" means the certificate is trusted. Why they couldn't name this function + // something with "trusted" instead is left as an exercise to the reader. + return X509Ptr(::PEM_read_bio_X509_AUX(&bio, nullptr, nullptr, empty_passphrase())); +} + +// Attempt to read a PEM encoded certificate from the given BIO. +// BIO might contain further certificates if function returns non-nullptr. +// Returns nullptr if no certificate could be loaded. This usually implies +// that there are no more certificates left in the chain. +X509Ptr read_untrusted_x509_from_bio(::BIO& bio) { + return X509Ptr(::PEM_read_bio_X509(&bio, nullptr, nullptr, empty_passphrase())); +} + +SslCtxPtr new_tls_ctx_with_auto_init() { + ensure_openssl_initialized_once(); +#if (OPENSSL_VERSION_NUMBER < 0x10100000L) + return SslCtxPtr(::SSL_CTX_new(::TLSv1_2_method())); +#else + SslCtxPtr ctx(::SSL_CTX_new(::TLS_method())); + if (!::SSL_CTX_set_min_proto_version(ctx.get(), TLS1_2_VERSION)) { + throw CryptoException("SSL_CTX_set_min_proto_version"); + } + return ctx; +#endif +} + +} // anon ns + +OpenSslTlsContextImpl::OpenSslTlsContextImpl(const TransportSecurityOptions& ts_opts) + : _ctx(new_tls_ctx_with_auto_init()) +{ + if (!_ctx) { + throw CryptoException("Failed to create new TLS context"); + } + add_certificate_authorities(ts_opts.ca_certs_pem()); + add_certificate_chain(ts_opts.cert_chain_pem()); + use_private_key(ts_opts.private_key_pem()); + verify_private_key(); + enable_ephemeral_key_exchange(); + disable_compression(); + enforce_peer_certificate_verification(); + // TODO set accepted cipher suites! + // TODO `--> If not set in options, use Modern spec from https://wiki.mozilla.org/Security/Server_Side_TLS +} + +OpenSslTlsContextImpl::~OpenSslTlsContextImpl() = default; + +void OpenSslTlsContextImpl::add_certificate_authorities(vespalib::stringref ca_pem) { + // TODO support empty CA set...? Ever useful? + auto bio = bio_from_string(ca_pem); + ::X509_STORE* cert_store = ::SSL_CTX_get_cert_store(_ctx.get()); // Internal pointer, not owned by us. + while (true) { + auto ca_cert = read_untrusted_x509_from_bio(*bio); + if (!ca_cert) { + break; + } + if (::X509_STORE_add_cert(cert_store, ca_cert.get()) != 1) { // Does _not_ take ownership + throw CryptoException("X509_STORE_add_cert"); + } + } +} + +void OpenSslTlsContextImpl::add_certificate_chain(vespalib::stringref chain_pem) { + ::ERR_clear_error(); + auto bio = bio_from_string(chain_pem); + // First certificate in the chain is the node's own (trusted) certificate. + auto own_cert = read_trusted_x509_from_bio(*bio); + if (!own_cert) { + throw CryptoException("No X509 certificates could be found in provided chain"); + } + // Ownership of certificate is _not_ transferred, OpenSSL makes internal copy. + // This is not well documented, but is mentioned by other impls. + if (::SSL_CTX_use_certificate(_ctx.get(), own_cert.get()) != 1) { + throw CryptoException("SSL_CTX_use_certificate"); + } + // After the node's own certificate comes any intermediate CA-provided certificates. + while (true) { + auto ca_cert = read_untrusted_x509_from_bio(*bio); + if (!ca_cert) { + // No more certificates in chain, hooray! + ::ERR_clear_error(); + break; + } + // Ownership of certificate _is_ transferred here! + if (!::SSL_CTX_add_extra_chain_cert(_ctx.get(), ca_cert.release())) { + throw CryptoException("SSL_CTX_add_extra_chain_cert"); + } + } +} + +void OpenSslTlsContextImpl::use_private_key(vespalib::stringref key_pem) { + auto bio = bio_from_string(key_pem); + EvpPkeyPtr key(::PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, empty_passphrase())); + if (!key) { + throw CryptoException("Failed to read PEM private key data"); + } + // Ownership _not_ taken. + if (::SSL_CTX_use_PrivateKey(_ctx.get(), key.get()) != 1) { + throw CryptoException("SSL_CTX_use_PrivateKey"); + } +} + +void OpenSslTlsContextImpl::verify_private_key() { + if (::SSL_CTX_check_private_key(_ctx.get()) != 1) { + throw CryptoException("SSL_CTX_check_private_key failed; mismatch between public and private key?"); + } +} + +void OpenSslTlsContextImpl::enable_ephemeral_key_exchange() { +#if (OPENSSL_VERSION_NUMBER < 0x10100000L) +# if (OPENSSL_VERSION_NUMBER >= 0x10002000L) + // Always enabled by default on higher versions. + // Auto curve selection is preferred over using SSL_CTX_set_ecdh_tmp + if (!::SSL_CTX_set_ecdh_auto(_ctx.get(), 1)) { + throw CryptoException("SSL_CTX_set_ecdh_auto"); + } + // New ECDH key per connection. + ::SSL_CTX_set_options(_ctx.get(), SSL_OP_SINGLE_ECDH_USE); +# else + // Set explicit P-256 curve used for ECDH purposes. + EcKeyPtr ec_curve(::EC_KEY_new_by_curve_name(NID_X9_62_prime256v1)); + if (!ec_curve) { + throw CryptoException("EC_KEY_new_by_curve_name(NID_X9_62_prime256v1)"); + } + if (!::SSL_CTX_set_tmp_ecdh(_ctx.get(), ec_curve.get())) { + throw CryptoException("SSL_CTX_set_tmp_ecdh"); + } +# endif +#endif +} + +void OpenSslTlsContextImpl::disable_compression() { + // TLS stream compression is vulnerable to a host of chosen plaintext + // attacks (CRIME, BREACH etc), so disable it. + ::SSL_CTX_set_options(_ctx.get(), SSL_OP_NO_COMPRESSION); +} + +void OpenSslTlsContextImpl::enforce_peer_certificate_verification() { + // We require full mutual certificate verification. No way to configure + // out of this, at least not for the time being. + // TODO verification callback for custom CN/SAN etc checks. + SSL_CTX_set_verify(_ctx.get(), SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); +} + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.h b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.h new file mode 100644 index 00000000000..72f9f3b570d --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.h @@ -0,0 +1,29 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "openssl_typedefs.h" +#include <vespa/vespalib/net/tls/tls_context.h> +#include <vespa/vespalib/stllike/string.h> + +namespace vespalib::net::tls::impl { + +class OpenSslTlsContextImpl : public TlsContext { + SslCtxPtr _ctx; +public: + explicit OpenSslTlsContextImpl(const TransportSecurityOptions&); + ~OpenSslTlsContextImpl() override; + + ::SSL_CTX* native_context() const noexcept { return _ctx.get(); } +private: + // Note: single use per instance; does _not_ clear existing chain! + void add_certificate_authorities(stringref ca_pem); + void add_certificate_chain(stringref chain_pem); + void use_private_key(stringref key_pem); + void verify_private_key(); + // Enable use of ephemeral key exchange (ECDHE), allowing forward secrecy. + void enable_ephemeral_key_exchange(); + void disable_compression(); + void enforce_peer_certificate_verification(); +}; + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_typedefs.h b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_typedefs.h new file mode 100644 index 00000000000..afafe556338 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_typedefs.h @@ -0,0 +1,53 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <memory> +#include <openssl/ssl.h> +#include <openssl/crypto.h> +#include <openssl/x509.h> + +namespace vespalib::net::tls::impl { + +struct BioDeleter { + void operator()(::BIO* bio) const noexcept { + ::BIO_free(bio); + } +}; +using BioPtr = std::unique_ptr<::BIO, BioDeleter>; + +struct SslDeleter { + void operator()(::SSL* ssl) const noexcept { + ::SSL_free(ssl); + } +}; +using SslPtr = std::unique_ptr<::SSL, SslDeleter>; + +struct SslCtxDeleter { + void operator()(::SSL_CTX* ssl) const noexcept { + ::SSL_CTX_free(ssl); + } +}; +using SslCtxPtr = std::unique_ptr<::SSL_CTX, SslCtxDeleter>; + +struct X509Deleter { + void operator()(::X509* cert) const noexcept { + ::X509_free(cert); + } +}; +using X509Ptr = std::unique_ptr<::X509, X509Deleter>; + +struct EvpPkeyDeleter { + void operator()(::EVP_PKEY* pkey) const noexcept { + ::EVP_PKEY_free(pkey); + } +}; +using EvpPkeyPtr = std::unique_ptr<::EVP_PKEY, EvpPkeyDeleter>; + +struct EcKeyDeleter { + void operator()(::EC_KEY* ec_key) const noexcept { + ::EC_KEY_free(ec_key); + } +}; +using EcKeyPtr = std::unique_ptr<::EC_KEY, EcKeyDeleter>; + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/tls_context.cpp b/vespalib/src/vespa/vespalib/net/tls/tls_context.cpp new file mode 100644 index 00000000000..467838975e7 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/tls_context.cpp @@ -0,0 +1,11 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "tls_context.h" +#include <vespa/vespalib/net/tls/impl/openssl_tls_context_impl.h> + +namespace vespalib::net::tls { + +std::unique_ptr<TlsContext> TlsContext::create_default_context(const TransportSecurityOptions& opts) { + return std::make_unique<impl::OpenSslTlsContextImpl>(opts); +} + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/tls_context.h b/vespalib/src/vespa/vespalib/net/tls/tls_context.h new file mode 100644 index 00000000000..7292f43f88c --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/tls_context.h @@ -0,0 +1,16 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <memory> + +namespace vespalib::net::tls { + +class TransportSecurityOptions; + +struct TlsContext { + virtual ~TlsContext() = default; + + static std::unique_ptr<TlsContext> create_default_context(const TransportSecurityOptions&); +}; + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/tls_crypto_engine.cpp b/vespalib/src/vespa/vespalib/net/tls/tls_crypto_engine.cpp new file mode 100644 index 00000000000..72d9eacf37c --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/tls_crypto_engine.cpp @@ -0,0 +1,22 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "tls_crypto_engine.h" +#include "crypto_codec.h" +#include "crypto_codec_adapter.h" + +namespace vespalib { + +TlsCryptoEngine::TlsCryptoEngine(net::tls::TransportSecurityOptions tls_opts) + : _tls_ctx(net::tls::TlsContext::create_default_context(tls_opts)) +{ +} + +CryptoSocket::UP +TlsCryptoEngine::create_crypto_socket(SocketHandle socket, bool is_server) +{ + auto mode = is_server ? net::tls::CryptoCodec::Mode::Server : net::tls::CryptoCodec::Mode::Client; + auto codec = net::tls::CryptoCodec::create_default_codec(*_tls_ctx, mode); + return std::make_unique<net::tls::CryptoCodecAdapter>(std::move(socket), std::move(codec)); +} + +} // namespace vespalib diff --git a/vespalib/src/vespa/vespalib/net/tls/tls_crypto_engine.h b/vespalib/src/vespa/vespalib/net/tls/tls_crypto_engine.h new file mode 100644 index 00000000000..58fda2b3b21 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/tls_crypto_engine.h @@ -0,0 +1,23 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/vespalib/net/crypto_engine.h> +#include "transport_security_options.h" +#include "tls_context.h" + +namespace vespalib { + +/** + * Crypto engine implementing TLS. + **/ +class TlsCryptoEngine : public CryptoEngine +{ +private: + std::unique_ptr<net::tls::TlsContext> _tls_ctx; +public: + TlsCryptoEngine(net::tls::TransportSecurityOptions tls_opts); + CryptoSocket::UP create_crypto_socket(SocketHandle socket, bool is_server) override; +}; + +} // namespace vespalib diff --git a/vespalib/src/vespa/vespalib/net/tls/transport_security_options.cpp b/vespalib/src/vespa/vespalib/net/tls/transport_security_options.cpp new file mode 100644 index 00000000000..4e39fe4d7fa --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/transport_security_options.cpp @@ -0,0 +1,12 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "transport_security_options.h" +#include <openssl/crypto.h> + +namespace vespalib::net::tls { + +TransportSecurityOptions::~TransportSecurityOptions() { + OPENSSL_cleanse(&_private_key_pem[0], _private_key_pem.size()); +} + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/transport_security_options.h b/vespalib/src/vespa/vespalib/net/tls/transport_security_options.h new file mode 100644 index 00000000000..0a228388791 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/transport_security_options.h @@ -0,0 +1,30 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/vespalib/stllike/string.h> + +namespace vespalib::net::tls { + +class TransportSecurityOptions { + vespalib::string _ca_certs_pem; + vespalib::string _cert_chain_pem; + vespalib::string _private_key_pem; +public: + TransportSecurityOptions() = default; + + TransportSecurityOptions(vespalib::string ca_certs_pem, + vespalib::string cert_chain_pem, + vespalib::string private_key_pem) + : _ca_certs_pem(std::move(ca_certs_pem)), + _cert_chain_pem(std::move(cert_chain_pem)), + _private_key_pem(std::move(private_key_pem)) + {} + ~TransportSecurityOptions(); + + const vespalib::string& ca_certs_pem() const noexcept { return _ca_certs_pem; } + const vespalib::string& cert_chain_pem() const noexcept { return _cert_chain_pem; } + const vespalib::string& private_key_pem() const noexcept { return _private_key_pem; } +}; + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/transport_security_options_reading.cpp b/vespalib/src/vespa/vespalib/net/tls/transport_security_options_reading.cpp new file mode 100644 index 00000000000..05cfc797e51 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/transport_security_options_reading.cpp @@ -0,0 +1,102 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "transport_security_options_reading.h" +#include <vespa/vespalib/data/slime/slime.h> +#include <vespa/vespalib/util/exceptions.h> +#include <vespa/vespalib/io/fileutil.h> +#include <vespa/vespalib/io/mapped_file_input.h> +#include <vespa/vespalib/data/memory_input.h> + +namespace vespalib::net::tls { + +/* + + Proposed JSON format for TLS configuration file: + +{ + "files": { + "private-key": "myhost.key", + "ca-certificates": "my_cas.pem", + "certificates": "certs.pem" + }, + // for later: + "peer-taggers": [ + { + "requirements":[ + { + "field": "SAN" + "must-match": "DNS:foo.bar.baz.*" + } + ], + "tags": ["cluster-peers", "config-server"] // or "roles"? Avoid ambiguities with Athenz concepts + }, + { + "requirements":[ + { "field":"CN", "must-match": "config.blarg.*"} + ], + "tags": ["config-server"] + } + ] +} + + */ + +using namespace slime::convenience; + +namespace { + +constexpr const char* files_field = "files"; +constexpr const char* private_key_field = "private-key"; +constexpr const char* ca_certs_field = "ca-certificates"; +constexpr const char* certs_field = "certificates"; + +void verify_referenced_file_exists(const vespalib::string& file_path) { + if (!fileExists(file_path)) { + throw IllegalArgumentException(make_string("File '%s' referenced by TLS config does not exist", file_path.c_str())); + } +} + +vespalib::string load_file_referenced_by_field(const Cursor& cursor, const char* field) { + auto file_path = cursor[field].asString().make_string(); + if (file_path.empty()) { + throw IllegalArgumentException(make_string("TLS config field '%s' has not been set", field)); + } + verify_referenced_file_exists(file_path); + return File::readAll(file_path); +} + +std::unique_ptr<TransportSecurityOptions> load_from_input(Input& input) { + Slime root; + auto parsed = slime::JsonFormat::decode(input, root); + if (parsed == 0) { + throw IllegalArgumentException("Provided TLS config file is not valid JSON"); + } + auto& files = root[files_field]; + if (files.fields() == 0) { + throw IllegalArgumentException("TLS config root field 'files' is missing or empty"); + } + // Note: we do no look at the _contents_ of the files; this is deferred to the + // TLS context code which actually tries to extract key and certificate material + // from them. + auto ca_certs = load_file_referenced_by_field(files, ca_certs_field); + auto certs = load_file_referenced_by_field(files, certs_field); + auto priv_key = load_file_referenced_by_field(files, private_key_field); + + return std::make_unique<TransportSecurityOptions>(std::move(ca_certs), std::move(certs), std::move(priv_key)); +} + +} // anon ns + +std::unique_ptr<TransportSecurityOptions> read_options_from_json_string(const vespalib::string& json_data) { + MemoryInput file_input(json_data); + return load_from_input(file_input); +} + +std::unique_ptr<TransportSecurityOptions> read_options_from_json_file(const vespalib::string& file_path) { + MappedFileInput file_input(file_path); + if (!file_input.valid()) { + throw IllegalArgumentException(make_string("TLS config file '%s' could not be read", file_path.c_str())); + } + return load_from_input(file_input); +} + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/transport_security_options_reading.h b/vespalib/src/vespa/vespalib/net/tls/transport_security_options_reading.h new file mode 100644 index 00000000000..800b3b5ed0d --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/transport_security_options_reading.h @@ -0,0 +1,20 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "transport_security_options.h" +#include <memory> + +namespace vespalib::net::tls { + +// TODO consider renaming TransportSecurityOptions -> TlsConfig + +/** + * Throws IoException if file_path or any files referenced by it can't be accessed + * Throws IllegalArgumentException if file is not parseable as a valid TLS config file or + * if mandatory JSON fields are missing or incomplete. + */ +std::unique_ptr<TransportSecurityOptions> read_options_from_json_file(const vespalib::string& file_path); +// Same properties as read_options_from_json_file() +std::unique_ptr<TransportSecurityOptions> read_options_from_json_string(const vespalib::string& json_data); + +} diff --git a/vespalib/src/vespa/vespalib/test/CMakeLists.txt b/vespalib/src/vespa/vespalib/test/CMakeLists.txt index 4c2c65e8793..4eb47735ca7 100644 --- a/vespalib/src/vespa/vespalib/test/CMakeLists.txt +++ b/vespalib/src/vespa/vespalib/test/CMakeLists.txt @@ -1,5 +1,6 @@ # Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -vespa_add_library(vespalib_vespalib_test INTERFACE +vespa_add_library(vespalib_vespalib_test OBJECT SOURCES + make_tls_options_for_testing.cpp DEPENDS ) diff --git a/vespalib/src/vespa/vespalib/test/make_tls_options_for_testing.cpp b/vespalib/src/vespa/vespalib/test/make_tls_options_for_testing.cpp new file mode 100644 index 00000000000..e70914dec2f --- /dev/null +++ b/vespalib/src/vespa/vespalib/test/make_tls_options_for_testing.cpp @@ -0,0 +1,77 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "make_tls_options_for_testing.h" + +/* + * Generated with the following commands: + * + * openssl ecparam -name prime256v1 -genkey -out ca.key + * + * openssl req -new -x509 -nodes -key ca.key \ + * -sha256 -out ca.pem \ + * -subj '/C=US/L=LooneyVille/O=ACME/OU=ACME test CA/CN=acme.example.com' \ + * -days 10000 + * + * openssl ecparam -name prime256v1 -genkey -out host.key + * + * openssl req -new -key host.key -out host.csr \ + * -subj '/C=US/L=LooneyVille/O=Wile. E. Coyote, Ltd./CN=wile.example.com' \ + * -sha256 + * + * openssl x509 -req -in host.csr \ + * -CA ca.pem \ + * -CAkey ca.key \ + * -CAcreateserial \ + * -out host.pem \ + * -days 10000 \ + * -sha256 + * + * TODO generate keypairs and certs at test-time to avoid any hard-coding + * There certs are valid until 2046, so that buys us some time..! + */ + +// ca.pem +constexpr const char* ca_pem = R"(-----BEGIN CERTIFICATE----- +MIIBuDCCAV4CCQDpVjQIixTxvDAKBggqhkjOPQQDAjBkMQswCQYDVQQGEwJVUzEU +MBIGA1UEBwwLTG9vbmV5VmlsbGUxDTALBgNVBAoMBEFDTUUxFTATBgNVBAsMDEFD +TUUgdGVzdCBDQTEZMBcGA1UEAwwQYWNtZS5leGFtcGxlLmNvbTAeFw0xODA4MzEx +MDU3NDVaFw00NjAxMTYxMDU3NDVaMGQxCzAJBgNVBAYTAlVTMRQwEgYDVQQHDAtM +b29uZXlWaWxsZTENMAsGA1UECgwEQUNNRTEVMBMGA1UECwwMQUNNRSB0ZXN0IENB +MRkwFwYDVQQDDBBhY21lLmV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0D +AQcDQgAE1L7IzCN5pbyVnBATIHieuxq+hf9kWyn5yfjkXMhD52T5ITz1huq4nbiN +YtRoRP7XmipI60R/uiCHzERcsVz4rDAKBggqhkjOPQQDAgNIADBFAiEA6wmZDBca +y0aJ6ABtjbjx/vlmVDxdkaSZSgO8h2CkvIECIFktCkbZhDFfSvbqUScPOGuwkdGQ +L/EW2Bxp+1BPcYoZ +-----END CERTIFICATE-----)"; + +// host.pem +constexpr const char* cert_pem = R"(-----BEGIN CERTIFICATE----- +MIIBsTCCAVgCCQD6GfDh0ltpsjAKBggqhkjOPQQDAjBkMQswCQYDVQQGEwJVUzEU +MBIGA1UEBwwLTG9vbmV5VmlsbGUxDTALBgNVBAoMBEFDTUUxFTATBgNVBAsMDEFD +TUUgdGVzdCBDQTEZMBcGA1UEAwwQYWNtZS5leGFtcGxlLmNvbTAeFw0xODA4MzEx +MDU3NDVaFw00NjAxMTYxMDU3NDVaMF4xCzAJBgNVBAYTAlVTMRQwEgYDVQQHDAtM +b29uZXlWaWxsZTEeMBwGA1UECgwVV2lsZS4gRS4gQ295b3RlLCBMdGQuMRkwFwYD +VQQDDBB3aWxlLmV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE +e+Y4hxt66em0STviGUj6ZDbxzoLoubXWRml8JDFrEc2S2433KWw2npxYSKVCyo3a +/Vo33V8/H0WgOXioKEZJxDAKBggqhkjOPQQDAgNHADBEAiAN+87hQuGv3z0Ja2BV +b8PHq2vp3BJHjeMuxWu4BFPn0QIgYlvIHikspgGatXRNMZ1gPC0oCccsJFcie+Cw +zL06UPI= +-----END CERTIFICATE-----)"; + +// host.key +constexpr const char* key_pem = R"(-----BEGIN EC PARAMETERS----- +BggqhkjOPQMBBw== +-----END EC PARAMETERS----- +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEID6di2PFYn8hPrxPbkFDGkSqF+K8L520In7nx3g0jwzOoAoGCCqGSM49 +AwEHoUQDQgAEe+Y4hxt66em0STviGUj6ZDbxzoLoubXWRml8JDFrEc2S2433KWw2 +npxYSKVCyo3a/Vo33V8/H0WgOXioKEZJxA== +-----END EC PRIVATE KEY-----)"; + +namespace vespalib::test { + +vespalib::net::tls::TransportSecurityOptions make_tls_options_for_testing() { + return vespalib::net::tls::TransportSecurityOptions(ca_pem, cert_pem, key_pem); +} + +} // namespace vespalib::test diff --git a/vespalib/src/vespa/vespalib/test/make_tls_options_for_testing.h b/vespalib/src/vespa/vespalib/test/make_tls_options_for_testing.h new file mode 100644 index 00000000000..a1f1d5958f9 --- /dev/null +++ b/vespalib/src/vespa/vespalib/test/make_tls_options_for_testing.h @@ -0,0 +1,15 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/vespalib/net/tls/transport_security_options.h> + +namespace vespalib::test { + +/** + * Make security options allowing you to talk to yourself using + * TLS. This is intended for testing purposes only. + **/ +vespalib::net::tls::TransportSecurityOptions make_tls_options_for_testing(); + +} // namespace vespalib::test diff --git a/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.h b/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.h index 9679e6379f5..6e8fa368df7 100644 --- a/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.h +++ b/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.h @@ -23,10 +23,10 @@ namespace thread { class ThreadInit; } // init function when creating an executor to inject a frame with the // given name into the stack of all worker threads. -#define VESPA_THREAD_STACK_TAG(name) \ - int name(Runnable &worker) { \ - worker.run(); \ - return 1; \ +#define VESPA_THREAD_STACK_TAG(name) \ + int name(::vespalib::Runnable &worker) { \ + worker.run(); \ + return 1; \ } /** diff --git a/vespalog/pom.xml b/vespalog/pom.xml index 6443769afbe..7b167ee2c1c 100644 --- a/vespalog/pom.xml +++ b/vespalog/pom.xml @@ -50,6 +50,10 @@ <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-surefire-plugin</artifactId> <configuration> + <forkMode>once</forkMode> + <environmentVariables> + <VESPA_HOME>${project.build.directory}</VESPA_HOME> + </environmentVariables> <redirectTestOutputToFile>${test.hide}</redirectTestOutputToFile> </configuration> </plugin> diff --git a/vespalog/src/main/java/com/yahoo/log/LogFileDb.java b/vespalog/src/main/java/com/yahoo/log/LogFileDb.java new file mode 100644 index 00000000000..d0fa64805bf --- /dev/null +++ b/vespalog/src/main/java/com/yahoo/log/LogFileDb.java @@ -0,0 +1,50 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.log; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.nio.file.StandardOpenOption.*; + +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +import static com.yahoo.vespa.defaults.Defaults.getDefaults; + + +/** + * @author arnej + * + * This class takes care of saving meta-data about a log-file, + * ensuring that we can enact policies about log retention. + **/ +public class LogFileDb { + + static final String DBDIR = "var/db/vespa/logfiledb/"; + + private static long dayStamp() { + long s = System.currentTimeMillis() / 1000; + return s / 100000; + } + + private static OutputStream metaFile() throws java.io.IOException { + String fn = getDefaults().underVespaHome(DBDIR + "logfiles." + dayStamp()); + Path path = Paths.get(fn); + return Files.newOutputStream(path, CREATE, APPEND); + } + + public static void nowLoggingTo(String filename) { + if (filename.contains("\n")) { + throw new IllegalArgumentException("Cannot use filename with newline: "+filename); + } + long s = System.currentTimeMillis() / 1000; + String meta = "" + s + " " + filename + "\n"; + byte[] data = meta.getBytes(UTF_8); + try (OutputStream out = metaFile()) { + out.write(data); + } catch (java.io.IOException e) { + System.err.println("Saving meta-data about logfile "+filename+" failed: "+e); + // ignore + } + } +} diff --git a/vespalog/src/test/java/com/yahoo/log/LogFileDbTest.java b/vespalog/src/test/java/com/yahoo/log/LogFileDbTest.java new file mode 100644 index 00000000000..4dd7bd0978c --- /dev/null +++ b/vespalog/src/test/java/com/yahoo/log/LogFileDbTest.java @@ -0,0 +1,29 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.log; + +import java.io.File; +import static com.yahoo.vespa.defaults.Defaults.getDefaults; +import org.junit.Test; + +/** + * @author arnej + */ +public class LogFileDbTest { + + @Test + public void canSave() { + System.err.println("VH: "+System.getenv("VESPA_HOME")); + File dir = new File(getDefaults().underVespaHome(LogFileDb.DBDIR)); + dir.mkdirs(); + if (dir.isDirectory()) { + System.err.println("using directory: "+dir); + new File(getDefaults().underVespaHome("logs/extra")).mkdirs(); + String fn = getDefaults().underVespaHome("logs/extra/foo-bar.log"); + LogFileDb.nowLoggingTo(fn); + fn = getDefaults().underVespaHome("logs/extra/stamped-1.log"); + LogFileDb.nowLoggingTo(fn); + } else { + System.err.println("cannot create directory: "+dir); + } + } +} diff --git a/vespamalloc/src/vespamalloc/util/osmem.cpp b/vespamalloc/src/vespamalloc/util/osmem.cpp index d7d32f4844a..f4fbb376265 100644 --- a/vespamalloc/src/vespamalloc/util/osmem.cpp +++ b/vespamalloc/src/vespamalloc/util/osmem.cpp @@ -10,7 +10,8 @@ namespace vespamalloc { -void * MmapMemory::reserve(size_t & len) +void * +MmapMemory::reserve(size_t & len) { len = 0; const size_t wLen(0x1000); @@ -20,10 +21,11 @@ void * MmapMemory::reserve(size_t & len) (void) test; setStart(wanted); setEnd(getStart()); - return NULL; + return nullptr; } -size_t findInMemInfo(const char * wanted) +size_t +findInMemInfo(const char * wanted) { size_t value(0); char memInfo[8192]; @@ -34,16 +36,17 @@ size_t findInMemInfo(const char * wanted) assert((sz < int(sizeof(memInfo))) && (sz >= 0)); memInfo[sz] = '\0'; const char * found(strstr(memInfo, wanted)); - if (found != NULL) { + if (found != nullptr) { found += strlen(wanted); - value = strtoul(found, NULL, 0); + value = strtoul(found, nullptr, 0); } close(fd); } return value; } -const char * getToken(const char * & s, const char * e) +const char * +getToken(const char * & s, const char * e) { for (; (s < e) && isspace(s[0]); s++) { } const char * c = s; @@ -51,7 +54,8 @@ const char * getToken(const char * & s, const char * e) return c; } -bool verifyHugePagesMount(const char * mount) +bool +verifyHugePagesMount(const char * mount) { const unsigned int HUGETLBFS_MAGIC(0x958458f6); struct statfs64 st; @@ -70,15 +74,17 @@ MmapMemory::MmapMemory(size_t blockSize) : setupHugePages(); } -void MmapMemory::setupFAdvise() +void +MmapMemory::setupFAdvise() { const char * madv = getenv("VESPA_MALLOC_MADVISE_LIMIT"); if (madv) { - _useMAdvLimit = strtoul(madv, NULL, 0); + _useMAdvLimit = strtoul(madv, nullptr, 0); } } -void MmapMemory::setupHugePages() +void +MmapMemory::setupHugePages() { _hugePagesFileName[0] = '\0'; const char * vespaHugePages = getenv("VESPA_MALLOC_HUGEPAGES"); @@ -140,23 +146,29 @@ MmapMemory::~MmapMemory() } } -void * MmapMemory::get(size_t len) +void * +MmapMemory::get(size_t len) { - void * memory(NULL); + void * memory(nullptr); + int prevErrno = errno; memory = getHugePages(len); - if (memory ==NULL) { + if (memory == nullptr) { + errno = prevErrno; // The temporary error should not impact if the end is good. memory = getNormalPages(len); } return memory; } -void * MmapMemory::getHugePages(size_t len) +void * +MmapMemory::getHugePages(size_t len) { - void * memory(NULL); + void * memory(nullptr); if ( ((len & 0x1fffff) == 0) && len) { + int prevErrno = errno; memory = getBasePages(len, MAP_ANON | MAP_PRIVATE | MAP_HUGETLB, -1, 0); - if (memory == NULL) { + if (memory == nullptr) { if (_hugePagesFd >= 0) { + errno = prevErrno; // The temporary error should not impact if the end is good. memory = getBasePages(len, MAP_SHARED, _hugePagesFd, _hugePagesOffset); if (memory) { _hugePagesOffset += len; @@ -167,21 +179,22 @@ void * MmapMemory::getHugePages(size_t len) return memory; } -void * MmapMemory::getNormalPages(size_t len) +void * +MmapMemory::getNormalPages(size_t len) { return getBasePages(len, MAP_ANON | MAP_PRIVATE, -1, 0); } -void * MmapMemory::getBasePages(size_t len, int mmapOpt, int fd, size_t offset) +void * +MmapMemory::getBasePages(size_t len, int mmapOpt, int fd, size_t offset) { char * wanted = reinterpret_cast<char *>(std::max(reinterpret_cast<size_t>(getEnd()), getMinPreferredStartAddress())); - void * mem(NULL); + void * mem(nullptr); for (bool ok(false) ; !ok && (mem != MAP_FAILED); wanted += getBlockAlignment()) { - if (mem != NULL) { + if (mem != nullptr) { int tmp(munmap(mem, len)); assert(tmp == 0); (void) tmp; - mem = NULL; } // no alignment to _blockSize needed? // both 0x10000000000ul*4 and 0x200000 are multiples of the current block size. @@ -189,7 +202,7 @@ void * MmapMemory::getBasePages(size_t len, int mmapOpt, int fd, size_t offset) ok = (mem == wanted); } if (mem != MAP_FAILED) { - if (getStart() == NULL) { + if (getStart() == nullptr) { setStart(mem); // assumes len parameter is always multiple of the current block size. setEnd(static_cast<char *>(mem)+len); @@ -198,10 +211,11 @@ void * MmapMemory::getBasePages(size_t len, int mmapOpt, int fd, size_t offset) } return mem; } - return NULL; + return nullptr; } -bool MmapMemory::release(void * mem, size_t len) +bool +MmapMemory::release(void * mem, size_t len) { int ret(0); if (_useMAdvLimit <= len) { @@ -214,7 +228,8 @@ bool MmapMemory::release(void * mem, size_t len) return true; } -bool MmapMemory::freeTail(void * mem, size_t len) +bool +MmapMemory::freeTail(void * mem, size_t len) { int ret(0); if ((_useMAdvLimit <= len) && (static_cast<char *>(mem) + len) == getEnd()) { @@ -225,7 +240,8 @@ bool MmapMemory::freeTail(void * mem, size_t len) return (ret == 0); } -bool MmapMemory::reclaim(void * mem, size_t len) +bool +MmapMemory::reclaim(void * mem, size_t len) { int ret(0); if (_useMAdvLimit <= len) { diff --git a/vespamalloc/src/vespamalloc/util/osmem.h b/vespamalloc/src/vespamalloc/util/osmem.h index 4ccc2bc112c..2faf3c9b181 100644 --- a/vespamalloc/src/vespamalloc/util/osmem.h +++ b/vespamalloc/src/vespamalloc/util/osmem.h @@ -13,7 +13,7 @@ namespace vespamalloc { class Memory { public: - Memory(size_t blockSize) : _blockSize(std::max(blockSize, size_t(getpagesize()))), _start(NULL), _end(NULL) { } + Memory(size_t blockSize) : _blockSize(std::max(blockSize, size_t(getpagesize()))), _start(nullptr), _end(nullptr) { } virtual ~Memory() { } void * getStart() const { return _start; } void * getEnd() const { return _end; } diff --git a/vsm/src/vespa/vsm/vsm/docsumconfig.cpp b/vsm/src/vespa/vsm/vsm/docsumconfig.cpp index 7df2205bf39..25c13967c49 100644 --- a/vsm/src/vespa/vsm/vsm/docsumconfig.cpp +++ b/vsm/src/vespa/vsm/vsm/docsumconfig.cpp @@ -23,7 +23,8 @@ DynamicDocsumConfig::createFieldWriter(const string & fieldName, const string & fieldWriter.reset(new EmptyDFW()); rc = true; } else if ((overrideName == "attribute") || - ((overrideName == "geopos"))) { + (overrideName == "attributecombiner") || + (overrideName == "geopos")) { rc = true; } else { fieldWriter = search::docsummary::DynamicDocsumConfig::createFieldWriter(fieldName, overrideName, argument, rc); |