diff options
219 files changed, 7266 insertions, 2839 deletions
diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/ca/CertificateSigner.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/ca/CertificateSigner.java index f188fba5074..8c851ed5489 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/ca/CertificateSigner.java +++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/ca/CertificateSigner.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.athenz.instanceproviderservice.ca; import com.google.common.collect.ImmutableList; diff --git a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/ca/CertificateSignerTest.java b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/ca/CertificateSignerTest.java index e691da0b2c3..480ff5679fe 100644 --- a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/ca/CertificateSignerTest.java +++ b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/ca/CertificateSignerTest.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.athenz.instanceproviderservice.ca; import com.yahoo.test.ManualClock; diff --git a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/ca/CsrSerializedPayloadTest.java b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/ca/CsrSerializedPayloadTest.java index b8433856f95..b12ef70b1dc 100644 --- a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/ca/CsrSerializedPayloadTest.java +++ b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/ca/CsrSerializedPayloadTest.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.athenz.instanceproviderservice.ca; import com.yahoo.vespa.hosted.athenz.instanceproviderservice.impl.Utils; diff --git a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGeneratorTest.java b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGeneratorTest.java index 0c12e137e27..7389cf1596d 100644 --- a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGeneratorTest.java +++ b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGeneratorTest.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.athenz.instanceproviderservice.identitydocument; diff --git a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidatorTest.java b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidatorTest.java index c68a8805abc..84105c5b551 100644 --- a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidatorTest.java +++ b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidatorTest.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.athenz.instanceproviderservice.instanceconfirmation; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/MockFileRegistry.java b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/MockFileRegistry.java index d635fe90ded..515477641a8 100644 --- a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/MockFileRegistry.java +++ b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/MockFileRegistry.java @@ -35,4 +35,9 @@ public class MockFileRegistry implements FileRegistry { return result; } + @Override + public FileReference addUri(String uri) { + throw new IllegalArgumentException("FileReference addUri(String uri) is not implemented for " + getClass().getCanonicalName()); + } + } diff --git a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/PreGeneratedFileRegistry.java b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/PreGeneratedFileRegistry.java index 0b0b799f47f..ed85b987a3d 100644 --- a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/PreGeneratedFileRegistry.java +++ b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/PreGeneratedFileRegistry.java @@ -74,6 +74,11 @@ public class PreGeneratedFileRegistry implements FileRegistry { } @Override + public FileReference addUri(String uri) { + return new FileReference(path2Hash.get(uri)); + } + + @Override public String fileSourceHost() { return fileSourceHost; } diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/FileRegistry.java b/config-model-api/src/main/java/com/yahoo/config/application/api/FileRegistry.java index 887da2d51c8..15ae4294762 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/FileRegistry.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/FileRegistry.java @@ -2,16 +2,17 @@ package com.yahoo.config.application.api; import java.util.List; -import java.util.Set; import com.yahoo.config.FileReference; + /** * @author tonytv */ public interface FileRegistry { FileReference addFile(String relativePath); + FileReference addUri(String uri); /** * Returns the name of the host which is the source of the files diff --git a/config-model-fat/pom.xml b/config-model-fat/pom.xml index 81fb1b29162..8688f12c199 100644 --- a/config-model-fat/pom.xml +++ b/config-model-fat/pom.xml @@ -28,6 +28,13 @@ <groupId>com.yahoo.vespa</groupId> <artifactId>config-model</artifactId> <version>${project.version}</version> + <exclusions> + <exclusion> + <!-- Large, and installed separately as part of Vespa --> + <groupId>org.tensorflow</groupId> + <artifactId>libtensorflow_jni</artifactId> + </exclusion> + </exclusions> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java index c65e0fad1c7..a2bdc6834c9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java @@ -12,12 +12,20 @@ import java.util.Objects; */ public class RankingConstant { + public enum PathType {FILE, URI}; + /** The search definition-unique name of this constant */ private final String name; private TensorType tensorType = null; - private String fileName = null; + private String path = null; private String fileReference = ""; + public PathType getPathType() { + return pathType; + } + + private PathType pathType = PathType.FILE; + public RankingConstant(String name) { this.name = name; } @@ -25,13 +33,20 @@ public class RankingConstant { public RankingConstant(String name, TensorType type, String fileName) { this(name); this.tensorType = type; - this.fileName = fileName; + this.path = fileName; validate(); } public void setFileName(String fileName) { Objects.requireNonNull(fileName, "Filename cannot be null"); - this.fileName = fileName; + this.path = fileName; + this.pathType = PathType.FILE; + } + + public void setUri(String uri) { + Objects.requireNonNull(uri, "uri cannot be null"); + this.path = uri; + this.pathType = PathType.URI; } /** @@ -43,14 +58,15 @@ public class RankingConstant { public void setType(TensorType tensorType) { this.tensorType = tensorType; } public String getName() { return name; } - public String getFileName() { return fileName; } + public String getFileName() { return path; } + public String getUri() { return path; } public String getFileReference() { return fileReference; } public TensorType getTensorType() { return tensorType; } public String getType() { return tensorType.toString(); } public void validate() { - if (fileName == null || fileName.isEmpty()) - throw new IllegalArgumentException("Ranking constants must have a file."); + if (path == null || path.isEmpty()) + throw new IllegalArgumentException("Ranking constants must have a file or uri."); if (tensorType == null) throw new IllegalArgumentException("Ranking constant '" + name + "' must have a type."); } @@ -58,7 +74,7 @@ public class RankingConstant { public String toString() { StringBuilder b = new StringBuilder(); b.append("constant '").append(name) - .append("' from file '").append(fileName) + .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) .append("' with ref '").append(fileReference) .append("' of type '").append(tensorType) .append("'"); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java index 316ff3dff40..7fcd2ed357a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.searchdefinition.RankProfile; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index c95601f6bbf..01d3449573c 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -33,7 +33,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.logging.Logger; /** * Replaces instances of the tensorflow(model-path, signature, output) @@ -46,8 +45,6 @@ import java.util.logging.Logger; // TODO: Avoid name conflicts across models for constants public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - private static final Logger log = Logger.getLogger(TensorFlowFeatureConverter.class.getName()); - private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ @@ -68,11 +65,11 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil try { ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), - feature.getArguments()); - if (store.hasTensorFlowModels()) - return transformFromTensorFlowModel(store, context.rankProfile()); - else // is should have previously stored model information instead + feature.getArguments()); + if (store.hasStoredModel()) return transformFromStoredModel(store, context.rankProfile()); + else // not converted yet - access TensorFlow model files + return transformFromTensorFlowModel(store, context.rankProfile()); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); @@ -185,12 +182,12 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil public FeatureArguments arguments() { return arguments; } - public boolean hasTensorFlowModels() { + public boolean hasStoredModel() { try { - return application.getFile(ApplicationPackage.MODELS_DIR).exists(); + return application.getFile(arguments.expressionPath()).exists(); } catch (UnsupportedOperationException e) { - return false; // No files -> no TensorFlow models + return false; } } @@ -206,7 +203,6 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil * Adds this expression to the application package, such that it can be read later. */ public void writeConverted(RankingExpression expression) { - log.info("Writing converted TensorFlow expression to " + arguments.expressionPath()); application.getFile(arguments.expressionPath()) .writeFile(new StringReader(expression.getRoot().toString())); } @@ -214,7 +210,6 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil /** Reads the previously stored ranking expression for these arguments */ public RankingExpression readConverted() { try { - log.info("Reading converted TensorFlow expression from " + arguments.expressionPath()); return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); } catch (IOException e) { @@ -261,12 +256,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } // Remember the constant in a file we replicate in ZooKeeper - log.info("Writing converted TensorFlow constant information to " + arguments.rankingConstantsPath().append(name + ".constant")); application.getFile(arguments.rankingConstantsPath().append(name + ".constant")) .writeFile(new StringReader(name + ":" + constant.type() + ":" + constantPathCorrected)); // Write content explicitly as a file on the file system as this is distributed using file distribution - log.info("Writing converted TensorFlow constant to " + application.getFileReference(constantPath).getAbsolutePath()); createIfNeeded(constantsPath); IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); return constantPathCorrected; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/ImportedFieldsInSummayValidator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/ImportedFieldsInSummayValidator.java index eaa85815736..0c1e46a474c 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/ImportedFieldsInSummayValidator.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/ImportedFieldsInSummayValidator.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; import com.yahoo.config.application.api.DeployLogger; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/AbstractService.java b/config-model/src/main/java/com/yahoo/vespa/model/AbstractService.java index 3b75be5167d..65f7bbedc68 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/AbstractService.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/AbstractService.java @@ -508,6 +508,13 @@ public abstract class AbstractService extends AbstractConfigProducer<AbstractCon throw new RuntimeException("File does not exist: '" + relativePath + "'."); } } + public FileReference sendUri(String uri) { + try { + return getRoot().getFileDistributor().sendUriToHost(uri, getHost()); + } catch (PathDoesNotExistException e) { + throw new RuntimeException("Uri does not exist: '" + uri + "'."); + } + } /** * diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java index 520ff231921..d022b2cf8ab 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java @@ -63,15 +63,19 @@ public class RankingConstantsValidator extends Validator { } private void validateRankingConstant(RankingConstant rankingConstant, ApplicationPackage application) throws FileNotFoundException { - String constantFile = rankingConstant.getFileName(); - if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) && - constantFile.startsWith(FilesApplicationPackage.preprocessed)) - constantFile = constantFile.substring(FilesApplicationPackage.preprocessed.length()); + // TODO: Handle validation of URI soon too. + if (rankingConstant.getPathType() == RankingConstant.PathType.FILE) { + String constantFile = rankingConstant.getFileName(); + if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) && + constantFile.startsWith(FilesApplicationPackage.preprocessed)) { + constantFile = constantFile.substring(FilesApplicationPackage.preprocessed.length()); + } - ApplicationFile tensorApplicationFile = application.getFile(Path.fromString(constantFile)); - new ConstantTensorJsonValidator().validate(constantFile, - rankingConstant.getTensorType(), - tensorApplicationFile.createReader()); + ApplicationFile tensorApplicationFile = application.getFile(Path.fromString(constantFile)); + new ConstantTensorJsonValidator().validate(constantFile, + rankingConstant.getTensorType(), + tensorApplicationFile.createReader()); + } } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/filedistribution/FileDistributor.java b/config-model/src/main/java/com/yahoo/vespa/model/filedistribution/FileDistributor.java index 213451da55e..e8d6a330358 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/filedistribution/FileDistributor.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/filedistribution/FileDistributor.java @@ -38,11 +38,31 @@ public class FileDistributor { return reference; } + /** + * Adds the given file to the associated application packages' registry of file and marks the file + * for distribution to the given hosts. + * <b>Note: This class receives ownership of the given collection.</b> + * + * @return the reference to the file, created by the application package + */ + public FileReference sendUriToHosts(String uri, Collection<Host> hosts) { + FileReference reference = fileRegistry.addUri(uri); + if (reference != null) { + addToFilesToDistribute(reference, hosts); + } + + return reference; + } + /** Same as sendFileToHost(relativePath,Collections.singletonList(host) */ public FileReference sendFileToHost(String relativePath, Host host) { return sendFileToHosts(relativePath, Arrays.asList(host)); } + public FileReference sendUriToHost(String uri, Host host) { + return sendUriToHosts(uri, Arrays.asList(host)); + } + private void addToFilesToDistribute(FileReference reference, Collection<Host> hosts) { Set<Host> oldHosts = getHosts(reference); oldHosts.addAll(hosts); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java index 58fc76f1508..9550cd82b22 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java @@ -38,7 +38,9 @@ public abstract class AbstractSearchCluster extends AbstractConfigProducer public void prepareToDistributeFiles(List<SearchNode> backends) { for (SearchDefinitionSpec sds : localSDS) { for (RankingConstant constant : sds.getSearchDefinition().getSearch().getRankingConstants().values()) { - FileReference reference = FileSender.sendFileToServices(constant.getFileName(), backends); + FileReference reference = (constant.getPathType() == RankingConstant.PathType.FILE) + ? FileSender.sendFileToServices(constant.getFileName(), backends) + : FileSender.sendUriToServices(constant.getUri(), backends); constant.setFileReference(reference.value()); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/utils/FileSender.java b/config-model/src/main/java/com/yahoo/vespa/model/utils/FileSender.java index 413363d7b0d..8995fcbca99 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/utils/FileSender.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/utils/FileSender.java @@ -20,6 +20,8 @@ import java.util.*; */ public class FileSender implements Serializable { + public enum FileType {FILE, URI}; + /** * Send the given file to all given services. * @@ -34,6 +36,7 @@ public class FileSender implements Serializable { throw new IllegalStateException("No service instances. Probably a standalone cluster setting up <nodes> " + "using 'count' instead of <node> tags."); } + FileReference fileref = null; for (AbstractService service : services) { // The same reference will be returned from each call. @@ -42,6 +45,20 @@ public class FileSender implements Serializable { return fileref; } + public static FileReference sendUriToServices(String uri, Collection<? extends AbstractService> services) { + if (services.isEmpty()) { + throw new IllegalStateException("No service instances. Probably a standalone cluster setting up <nodes> " + + "using 'count' instead of <node> tags."); + } + + FileReference fileref = null; + for (AbstractService service : services) { + // The same reference will be returned from each call. + fileref = service.sendUri(uri); + } + return fileref; + } + /** * Sends all user configured files for a producer to all given services. */ diff --git a/config-model/src/main/javacc/SDParser.jj b/config-model/src/main/javacc/SDParser.jj index 916a905dfcb..bf6376983a4 100644 --- a/config-model/src/main/javacc/SDParser.jj +++ b/config-model/src/main/javacc/SDParser.jj @@ -339,6 +339,7 @@ TOKEN : | < RANKSCOREDROPLIMIT: "rank-score-drop-limit" > | < CONSTANTS: "constants" > | < FILE: "file" > +| < URI: "uri" > | < IDENTIFIER: ["a"-"z","A"-"Z", "_"] (["a"-"z","A"-"Z","0"-"9","_","-"])* > | < QUOTEDSTRING: "\"" ( ~["\""] )* "\"" > | < CONTEXT: ["a"-"z","A"-"Z"] (["a"-"z", "A"-"Z", "0"-"9"])* > @@ -347,6 +348,8 @@ TOKEN : | < LONG: ("-")? (["0"-"9"])+"L" > | < STRING: (["a"-"z","A"-"Z","_","0"-"9","."])+ > | < FILE_PATH: ["a"-"z","A"-"Z", "_"] (["a"-"z","A"-"Z","0"-"9","_","-", "/", "."])+ > +| < HTTP: ["h","H"] ["t","T"] ["t","T"] ["p","P"] > +| < URI_PATH: <HTTP> <COLON> ("//")? (["a"-"z","A"-"Z","0"-"9","_","-", "/", ".",":"])+ > | < LESSTHAN: "<" > | < GREATERTHAN: ">" > | < VARIABLE: "$" <IDENTIFIER> > @@ -1805,11 +1808,12 @@ void rankingConstant(Search search) : */ Object rankingConstantItem(RankingConstant constant) : { - String fileName = null; + String path = null; TensorType type = null; } { - ( (<FILE> <COLON> fileName = filePath() { } (<NL>)*) { constant.setFileName(fileName); } + ( (<FILE> <COLON> path = filePath() { } (<NL>)*) { constant.setFileName(path); } + | (<URI> <COLON> path = uriPath() { } (<NL>)*) { constant.setUri(path); } | type = tensorTypeWithPrefix(rankingConstantErrorMessage(constant.getName())) (<NL>)* { constant.setType(type); } ) { @@ -1828,6 +1832,12 @@ String filePath() : { } { return token.image; } } +String uriPath() : { } +{ + ( <URI_PATH> ) + { return token.image; } +} + /** * Consumes a rank-profile block of a search element. * @@ -2550,6 +2560,7 @@ String identifier() : { } | <TRUE> | <TYPE> | <UCA> + | <URI> | <UPPERBOUND> | <USEDOCUMENT> | <VARIABLE> diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingConstantTest.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingConstantTest.java index 2880af9e74f..9bad5373191 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingConstantTest.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingConstantTest.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; +import com.yahoo.searchdefinition.parser.ParseException; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -48,6 +49,7 @@ public class RankingConstantTest { assertEquals(TENSOR_NAME, constant.getName()); assertEquals(TENSOR_FILE, constant.getFileName()); assertEquals(TENSOR_TYPE, constant.getType()); + assertEquals(RankingConstant.PathType.FILE, constant.getPathType()); assertFalse(constantIterator.hasNext()); } @@ -103,4 +105,80 @@ public class RankingConstantTest { assertEquals("simplename", constant.getFileName()); } + @Test + public void constant_uri_is_allowed() throws Exception { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder searchBuilder = new SearchBuilder(rankProfileRegistry); + searchBuilder.importString(joinLines( + "search test {", + " document test { }", + " constant foo {", + " type: tensor(x{})", + " uri: http://somwhere.far.away/in/another-galaxy", + " }", + "}" + )); + searchBuilder.build(); + Search search = searchBuilder.getSearch(); + RankingConstant constant = search.getRankingConstants().values().iterator().next(); + assertEquals(RankingConstant.PathType.URI, constant.getPathType()); + assertEquals("http://somwhere.far.away/in/another-galaxy", constant.getUri()); + } + @Test + public void constant_uri_with_port_is_allowed() throws Exception { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder searchBuilder = new SearchBuilder(rankProfileRegistry); + searchBuilder.importString(joinLines( + "search test {", + " document test { }", + " constant foo {", + " type: tensor(x{})", + " uri: http://somwhere.far.away:4080/in/another-galaxy", + " }", + "}" + )); + searchBuilder.build(); + Search search = searchBuilder.getSearch(); + RankingConstant constant = search.getRankingConstants().values().iterator().next(); + assertEquals(RankingConstant.PathType.URI, constant.getPathType()); + assertEquals("http://somwhere.far.away:4080/in/another-galaxy", constant.getUri()); + } + @Test + public void constant_uri_no_dual_slashes_is_allowed() throws Exception { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder searchBuilder = new SearchBuilder(rankProfileRegistry); + searchBuilder.importString(joinLines( + "search test {", + " document test { }", + " constant foo {", + " type: tensor(x{})", + " uri: http:somwhere.far.away/in/another-galaxy", + " }", + "}" + )); + searchBuilder.build(); + Search search = searchBuilder.getSearch(); + RankingConstant constant = search.getRankingConstants().values().iterator().next(); + assertEquals(RankingConstant.PathType.URI, constant.getPathType()); + assertEquals("http:somwhere.far.away/in/another-galaxy", constant.getUri()); + } + @Test + public void constant_uri_only_supports_http() throws Exception { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder searchBuilder = new SearchBuilder(rankProfileRegistry); + thrown.expect(ParseException.class); + thrown.expectMessage("Encountered \" <IDENTIFIER> \"ftp \"\" at line 5, column 10.\n" + + "Was expecting:\n" + + " <URI_PATH> ..."); + searchBuilder.importString(joinLines( + "search test {", + " document test { }", + " constant foo {", + " type: tensor(x{})", + " uri: ftp:somwhere.far.away/in/another-galaxy", + " }", + "}" + )); + } + } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImportedFieldsInSummaryValidatorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImportedFieldsInSummaryValidatorTestCase.java index 4e1f8f1edd7..c600f447f01 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImportedFieldsInSummaryValidatorTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImportedFieldsInSummaryValidatorTestCase.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; import com.yahoo.document.DataType; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 7c749608e1f..8ba0a42f799 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; import com.yahoo.config.application.api.ApplicationPackage; diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/utils/DocType.java b/config-model/src/test/java/com/yahoo/vespa/model/content/utils/DocType.java index 3a5f679509b..cc5ec34b43e 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/content/utils/DocType.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/utils/DocType.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.content.utils; import java.util.Arrays; diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/AddFileInterface.java b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/AddFileInterface.java index 61c376a7256..39bf2ee02fd 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/AddFileInterface.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/AddFileInterface.java @@ -4,6 +4,8 @@ package com.yahoo.vespa.config.server.filedistribution; import com.yahoo.config.FileReference; public interface AddFileInterface { + FileReference addUri(String uri, String relativePath); + FileReference addUri(String uri, String relativePath, FileReference reference); FileReference addFile(String relativePath); FileReference addFile(String relativePath, FileReference reference); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/ApplicationFileManager.java b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/ApplicationFileManager.java index 82535143c89..0e7aa4a4fd2 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/ApplicationFileManager.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/ApplicationFileManager.java @@ -3,6 +3,12 @@ package com.yahoo.vespa.config.server.filedistribution; import com.yahoo.config.FileReference; import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.net.URL; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.nio.file.Files; public class ApplicationFileManager implements AddFileInterface { @@ -24,4 +30,41 @@ public class ApplicationFileManager implements AddFileInterface { return master.addFile(new File(applicationDir, relativePath)); } + @Override + public FileReference addUri(String uri, String relativePath) { + download(uri, relativePath); + return addFile(relativePath); + } + + @Override + public FileReference addUri(String uri, String relativePath, FileReference reference) { + download(uri, relativePath); + return addFile(relativePath, reference); + } + + void download(String uri, String relativePath) { + File file = new File(applicationDir, relativePath); + FileOutputStream fos = null; + ReadableByteChannel rbc = null; + try { + Files.createDirectories(file.toPath().getParent()); + URL website = new URL(uri); + rbc = Channels.newChannel(website.openStream()); + fos = new FileOutputStream(file.getAbsolutePath()); + fos.getChannel().transferFrom(rbc, 0, Long.MAX_VALUE); + } catch (IOException e) { + throw new IllegalArgumentException("Failed creating directory " + file.getParent(), e); + } finally { + try { + if (fos != null) { + fos.close(); + } + if (rbc != null) { + rbc.close(); + } + } catch (IOException e) { + throw new IllegalArgumentException("Failed closing down after downloading " + uri + " to " + file.getAbsolutePath()); + } + } + } } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/CombinedLegacyRegistry.java b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/CombinedLegacyRegistry.java index 8f2cb194bbd..1fe72e27461 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/CombinedLegacyRegistry.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/CombinedLegacyRegistry.java @@ -21,6 +21,12 @@ public class CombinedLegacyRegistry implements FileRegistry { } @Override + public FileReference addUri(String uri) { + FileReference reference = future.addUri(uri); + return reference; + } + + @Override public String fileSourceHost() { return future.fileSourceHost(); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDBRegistry.java b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDBRegistry.java index 1a76454fbed..b0a802d831d 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDBRegistry.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDBRegistry.java @@ -4,7 +4,10 @@ package com.yahoo.vespa.config.server.filedistribution; import com.yahoo.config.FileReference; import com.yahoo.config.application.api.FileRegistry; import com.yahoo.net.HostName; +import com.yahoo.text.Utf8; +import net.jpountz.xxhash.XXHashFactory; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -34,6 +37,17 @@ public class FileDBRegistry implements FileRegistry { }); } + public synchronized FileReference addUri(String uri, FileReference reference) { + String relativePath = uriToRelativeFile(uri); + Optional<FileReference> cachedReference = Optional.ofNullable(fileReferenceCache.get(uri)); + return cachedReference.orElseGet(() -> { + FileReference newRef = manager.addUri(uri, relativePath, reference); + entries.add(new Entry(uri, newRef)); + fileReferenceCache.put(uri, newRef); + return newRef; + }); + } + @Override public synchronized FileReference addFile(String relativePath) { Optional<FileReference> cachedReference = Optional.ofNullable(fileReferenceCache.get(relativePath)); @@ -46,6 +60,18 @@ public class FileDBRegistry implements FileRegistry { } @Override + public synchronized FileReference addUri(String uri) { + String relativePath = uriToRelativeFile(uri); + Optional<FileReference> cachedReference = Optional.ofNullable(fileReferenceCache.get(uri)); + return cachedReference.orElseGet(() -> { + FileReference newRef = manager.addUri(uri, relativePath); + entries.add(new Entry(uri, newRef)); + fileReferenceCache.put(uri, newRef); + return newRef; + }); + } + + @Override public String fileSourceHost() { return HostName.getLocalhost(); } @@ -55,4 +81,16 @@ public class FileDBRegistry implements FileRegistry { return entries; } + private static String uriToRelativeFile(String uri) { + String relative = "uri/" + String.valueOf(XXHashFactory.nativeInstance().hash64().hash(ByteBuffer.wrap(Utf8.toBytes(uri)), 0)); + if (uri.endsWith(".json")) { + relative += ".json"; + } else if (uri.endsWith(".json.lz4")) { + relative += ".json.lz4"; + } else if (uri.endsWith(".lz4")) { + relative += ".lz4"; + } + return relative; + } + } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionProvider.java b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionProvider.java index 6c2da338ef0..117bf3e236b 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionProvider.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionProvider.java @@ -26,6 +26,17 @@ public class FileDistributionProvider { ManagerWrapper(FileDistributionManager manager) { this.manager = manager; } + + @Override + public FileReference addUri(String uri, String relativePath) { + throw new IllegalStateException("addUri is not possible with legacy filedistribution."); + } + + @Override + public FileReference addUri(String uri, String relativePath, FileReference reference) { + throw new IllegalStateException("addUri is not possible with legacy filedistribution."); + } + @Override public FileReference addFile(String relativePath) { return new FileReference(manager.addFile(relativePath)); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileServer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileServer.java index 906182396d9..dbd8fdda052 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileServer.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileServer.java @@ -120,6 +120,7 @@ public class FileServer { } target.receive(fileData, new ReplayStatus(success ? 0 : 1, success ? "OK" : errorDescription)); + fileData.close(); log.log(LogLevel.DEBUG, "Done serving reference '" + reference.toString() + "' with file '" + file.getAbsolutePath() + "'"); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/status/StatusHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/status/StatusHandler.java index fb1108c18c2..5c0439c0af3 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/status/StatusHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/status/StatusHandler.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.config.server.http.status; import com.google.inject.Inject; diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ModelsBuilder.java b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ModelsBuilder.java index ea8405f6b65..731e343532a 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ModelsBuilder.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ModelsBuilder.java @@ -104,6 +104,7 @@ public abstract class ModelsBuilder<MODELRESULT extends ModelResult> { log.log(LogLevel.WARNING, "Unexpected error when building model ", e); throw new InternalServerException(applicationId + ": Error loading model", e); } else { + log.log(LogLevel.WARNING, "Input error when building model ", e); throw new IllegalArgumentException(applicationId + ": Error loading model", e); } } else { diff --git a/configserver/src/test/resources/deploy/advancedapp/deployment.xml b/configserver/src/test/resources/deploy/advancedapp/deployment.xml index fa1d1388e67..46451eb3787 100644 --- a/configserver/src/test/resources/deploy/advancedapp/deployment.xml +++ b/configserver/src/test/resources/deploy/advancedapp/deployment.xml @@ -1 +1,2 @@ -<deployment version='1.0'/>
\ No newline at end of file +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<deployment version='1.0'/> diff --git a/configserver/src/test/resources/deploy/app/deployment.xml b/configserver/src/test/resources/deploy/app/deployment.xml index fa1d1388e67..46451eb3787 100644 --- a/configserver/src/test/resources/deploy/app/deployment.xml +++ b/configserver/src/test/resources/deploy/app/deployment.xml @@ -1 +1,2 @@ -<deployment version='1.0'/>
\ No newline at end of file +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<deployment version='1.0'/> diff --git a/configserver/src/test/resources/deploy/validapp/deployment.xml b/configserver/src/test/resources/deploy/validapp/deployment.xml index fa1d1388e67..46451eb3787 100644 --- a/configserver/src/test/resources/deploy/validapp/deployment.xml +++ b/configserver/src/test/resources/deploy/validapp/deployment.xml @@ -1 +1,2 @@ -<deployment version='1.0'/>
\ No newline at end of file +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<deployment version='1.0'/> diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/state/StateMonitor.java b/container-core/src/main/java/com/yahoo/container/jdisc/state/StateMonitor.java index 6234a96d7a0..6ccd25ad6c7 100644 --- a/container-core/src/main/java/com/yahoo/container/jdisc/state/StateMonitor.java +++ b/container-core/src/main/java/com/yahoo/container/jdisc/state/StateMonitor.java @@ -31,7 +31,7 @@ public class StateMonitor extends AbstractComponent { private final Thread thread; private final Timer timer; private final long snapshotIntervalMs; - private long lastSnapshotTimeMs; + private volatile long lastSnapshotTimeMs; private volatile MetricSnapshot snapshot; private volatile Status status; private final TreeSet<String> valueNames = new TreeSet<>(); diff --git a/container-core/src/test/java/com/yahoo/container/jdisc/state/StateHandlerTest.java b/container-core/src/test/java/com/yahoo/container/jdisc/state/StateHandlerTest.java index 100650d43bd..5d8f885f8d0 100644 --- a/container-core/src/test/java/com/yahoo/container/jdisc/state/StateHandlerTest.java +++ b/container-core/src/test/java/com/yahoo/container/jdisc/state/StateHandlerTest.java @@ -30,6 +30,7 @@ import java.util.HashMap; import java.util.Map; import java.util.TreeMap; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -46,7 +47,7 @@ public class StateHandlerTest { private TestDriver driver; private StateMonitor monitor; private Metric metric; - private volatile long currentTimeMillis = 0; + private final AtomicLong currentTimeMillis = new AtomicLong(0); @Before public void startTestDriver() { @@ -58,7 +59,7 @@ public class StateHandlerTest { @Override public long currentTimeMillis() { - return currentTimeMillis; + return currentTimeMillis.get(); } }); } @@ -400,7 +401,7 @@ public class StateHandlerTest { } private void incrementCurrentTime(long val) { - currentTimeMillis += val; + currentTimeMillis.addAndGet(val); monitor.checkTime(); } diff --git a/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/CryptoUtilsTest.java b/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/CryptoUtilsTest.java index 2a265a3c6fd..dc9690355e8 100644 --- a/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/CryptoUtilsTest.java +++ b/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/CryptoUtilsTest.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.container.jdisc.athenz.impl; import org.bouncycastle.pkcs.PKCS10CertificationRequest; diff --git a/container-search/src/main/java/com/yahoo/search/Query.java b/container-search/src/main/java/com/yahoo/search/Query.java index b2349ed6dfc..20f87afacc1 100644 --- a/container-search/src/main/java/com/yahoo/search/Query.java +++ b/container-search/src/main/java/com/yahoo/search/Query.java @@ -464,7 +464,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { QueryProfileProperties queryProfileProperties = properties().getInstance(QueryProfileProperties.class); if (queryProfileProperties == null) return null; // Valid StringBuilder missingName = new StringBuilder(); - if (! queryProfileProperties.isComplete(missingName, httpRequest.propertyMap())) + if ( ! queryProfileProperties.isComplete(missingName, httpRequest.propertyMap())) return "Incomplete query: Parameter '" + missingName + "' is mandatory in " + queryProfileProperties.getQueryProfile() + " but is not set"; else diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/AllTypesQueryProfileVisitor.java b/container-search/src/main/java/com/yahoo/search/query/profile/AllTypesQueryProfileVisitor.java index c194ec235bd..4b83b716635 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/AllTypesQueryProfileVisitor.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/AllTypesQueryProfileVisitor.java @@ -32,7 +32,7 @@ final class AllTypesQueryProfileVisitor extends PrefixQueryProfileVisitor { } private void addReachableTypes(CompoundName name, QueryProfileType type) { - types.put(name, type); + types.putIfAbsent(name, type); // Types visited earlier has precedence: profile.type overrides profile.inherited.type for (FieldDescription fieldDescription : type.fields().values()) { if ( ! (fieldDescription.getType() instanceof QueryProfileFieldType)) continue; QueryProfileFieldType fieldType = (QueryProfileFieldType)fieldDescription.getType(); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java index ecbdebf1524..04dd3ee9005 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java @@ -12,7 +12,12 @@ import com.yahoo.search.query.profile.types.FieldDescription; import com.yahoo.search.query.profile.types.QueryProfileFieldType; import com.yahoo.search.query.profile.types.QueryProfileType; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/MandatoryTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/MandatoryTestCase.java index b91bcbfba69..7dc6eb3d8aa 100644 --- a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/MandatoryTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/MandatoryTestCase.java @@ -1,9 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.query.profile.types.test; -import com.yahoo.jdisc.http.HttpRequest.Method; -import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.component.ComponentId; +import com.yahoo.container.jdisc.HttpRequest; +import com.yahoo.jdisc.http.HttpRequest.Method; import com.yahoo.search.Query; import com.yahoo.search.query.profile.QueryProfile; import com.yahoo.search.query.profile.QueryProfileRegistry; @@ -13,79 +13,84 @@ import com.yahoo.search.query.profile.types.FieldType; import com.yahoo.search.query.profile.types.QueryProfileType; import com.yahoo.search.query.profile.types.QueryProfileTypeRegistry; import com.yahoo.search.test.QueryTestCase; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; /** * @author bratseth */ -public class MandatoryTestCase extends junit.framework.TestCase { - - private QueryProfileTypeRegistry registry; - - private QueryProfileType type, user; - - protected @Override void setUp() { - type=new QueryProfileType(new ComponentId("testtype")); - user=new QueryProfileType(new ComponentId("user")); - registry=new QueryProfileTypeRegistry(); - registry.register(type); - registry.register(user); +public class MandatoryTestCase { + + private static class Fixture1 { + + final QueryProfileRegistry registry = new QueryProfileRegistry(); + final QueryProfileTypeRegistry typeRegistry = new QueryProfileTypeRegistry(); + final QueryProfileType type = new QueryProfileType(new ComponentId("testtype")); + final QueryProfileType user = new QueryProfileType(new ComponentId("user")); + + public Fixture1() { + typeRegistry.register(type); + typeRegistry.register(user); + + addTypeFields(type, typeRegistry); + addUserFields(user, typeRegistry); + } + + private static void addTypeFields(QueryProfileType type, QueryProfileTypeRegistry registry) { + type.addField(new FieldDescription("myString", FieldType.fromString("string", registry), true)); + type.addField(new FieldDescription("myInteger", FieldType.fromString("integer", registry))); + type.addField(new FieldDescription("myLong", FieldType.fromString("long", registry))); + type.addField(new FieldDescription("myFloat", FieldType.fromString("float", registry))); + type.addField(new FieldDescription("myDouble", FieldType.fromString("double", registry))); + type.addField(new FieldDescription("myQueryProfile", FieldType.fromString("query-profile", registry))); + type.addField(new FieldDescription("myUserQueryProfile", FieldType.fromString("query-profile:user", registry), true)); + } + + private static void addUserFields(QueryProfileType user, QueryProfileTypeRegistry registry) { + user.addField(new FieldDescription("myUserString", FieldType.fromString("string", registry), true)); + user.addField(new FieldDescription("myUserInteger", FieldType.fromString("integer", registry), true)); + } - addTypeFields(type); - addUserFields(user); - } - - private void addTypeFields(QueryProfileType type) { - boolean mandatory=true; - type.addField(new FieldDescription("myString", FieldType.fromString("string",registry), mandatory)); - type.addField(new FieldDescription("myInteger",FieldType.fromString("integer",registry))); - type.addField(new FieldDescription("myLong",FieldType.fromString("long",registry))); - type.addField(new FieldDescription("myFloat",FieldType.fromString("float",registry))); - type.addField(new FieldDescription("myDouble",FieldType.fromString("double",registry))); - type.addField(new FieldDescription("myQueryProfile",FieldType.fromString("query-profile",registry))); - type.addField(new FieldDescription("myUserQueryProfile", FieldType.fromString("query-profile:user",registry),mandatory)); - } - - private void addUserFields(QueryProfileType user) { - boolean mandatory=true; - user.addField(new FieldDescription("myUserString",FieldType.fromString("string",registry),mandatory)); - user.addField(new FieldDescription("myUserInteger",FieldType.fromString("integer",registry),mandatory)); } + @Test public void testMandatoryFullySpecifiedQueryProfile() { - QueryProfileRegistry registry = new QueryProfileRegistry(); + Fixture1 fixture = new Fixture1(); QueryProfile test=new QueryProfile("test"); - test.setType(type); - test.set("myString","aString", registry); - registry.register(test); + test.setType(fixture.type); + test.set("myString", "aString", fixture.registry); + fixture.registry.register(test); QueryProfile myUser=new QueryProfile("user"); - myUser.setType(user); - myUser.set("myUserInteger",1, registry); - myUser.set("myUserString",1, registry); - test.set("myUserQueryProfile", myUser, registry); - registry.register(myUser); + myUser.setType(fixture.user); + myUser.set("myUserInteger",1, fixture.registry); + myUser.set("myUserString",1, fixture.registry); + test.set("myUserQueryProfile", myUser, fixture.registry); + fixture.registry.register(myUser); - CompiledQueryProfileRegistry cRegistry = registry.compile(); + CompiledQueryProfileRegistry cRegistry = fixture.registry.compile(); // Fully specified request assertError(null, new Query(QueryTestCase.httpEncode("?queryProfile=test"), cRegistry.getComponent("test"))); } + @Test public void testMandatoryRequestPropertiesNeeded() { - QueryProfileRegistry registry = new QueryProfileRegistry(); + Fixture1 fixture = new Fixture1(); - QueryProfile test=new QueryProfile("test"); - test.setType(type); - registry.register(test); + QueryProfile test = new QueryProfile("test"); + test.setType(fixture.type); + fixture.registry.register(test); - QueryProfile myUser=new QueryProfile("user"); - myUser.setType(user); - myUser.set("myUserInteger",1, registry); - test.set("myUserQueryProfile",myUser, registry); - registry.register(myUser); + QueryProfile myUser = new QueryProfile("user"); + myUser.setType(fixture.user); + myUser.set("myUserInteger", 1, fixture.registry); + test.set("myUserQueryProfile", myUser, fixture.registry); + fixture.registry.register(myUser); - CompiledQueryProfileRegistry cRegistry = registry.compile(); + CompiledQueryProfileRegistry cRegistry = fixture.registry.compile(); // Underspecified request 1 assertError("Incomplete query: Parameter 'myString' is mandatory in query profile 'test' of type 'testtype' but is not set", @@ -100,29 +105,30 @@ public class MandatoryTestCase extends junit.framework.TestCase { } /** Same as above except the whole thing is nested in maps */ + @Test public void testMandatoryNestedInMaps() { - QueryProfileRegistry registry = new QueryProfileRegistry(); + Fixture1 fixture = new Fixture1(); - QueryProfile topMap=new QueryProfile("topMap"); - registry.register(topMap); + QueryProfile topMap = new QueryProfile("topMap"); + fixture.registry.register(topMap); - QueryProfile subMap=new QueryProfile("topSubMap"); - topMap.set("subMap",subMap, registry); - registry.register(subMap); + QueryProfile subMap = new QueryProfile("topSubMap"); + topMap.set("subMap", subMap, fixture.registry); + fixture.registry.register(subMap); - QueryProfile test=new QueryProfile("test"); - test.setType(type); - subMap.set("test",test, registry); - registry.register(test); + QueryProfile test = new QueryProfile("test"); + test.setType(fixture.type); + subMap.set("test", test, fixture.registry); + fixture.registry.register(test); - QueryProfile myUser=new QueryProfile("user"); - myUser.setType(user); - myUser.set("myUserInteger",1, registry); - test.set("myUserQueryProfile",myUser, registry); - registry.register(myUser); + QueryProfile myUser = new QueryProfile("user"); + myUser.setType(fixture.user); + myUser.set("myUserInteger",1, fixture.registry); + test.set("myUserQueryProfile", myUser, fixture.registry); + fixture.registry.register(myUser); - CompiledQueryProfileRegistry cRegistry = registry.compile(); + CompiledQueryProfileRegistry cRegistry = fixture.registry.compile(); // Underspecified request 1 assertError("Incomplete query: Parameter 'subMap.test.myString' is mandatory in query profile 'topMap' but is not set", @@ -137,13 +143,16 @@ public class MandatoryTestCase extends junit.framework.TestCase { } /** Here, no user query profile is referenced in the query profile, but one is chosen in the request */ + @Test public void testMandatoryUserProfileSetInRequest() { - QueryProfile test=new QueryProfile("test"); - test.setType(type); + Fixture1 fixture = new Fixture1(); - QueryProfile myUser=new QueryProfile("user"); - myUser.setType(user); - myUser.set("myUserInteger",1, (QueryProfileRegistry)null); + QueryProfile test = new QueryProfile("test"); + test.setType(fixture.type); + + QueryProfile myUser = new QueryProfile("user"); + myUser.setType(fixture.user); + myUser.set("myUserInteger", 1, null); QueryProfileRegistry registry = new QueryProfileRegistry(); registry.register(test); @@ -163,25 +172,27 @@ public class MandatoryTestCase extends junit.framework.TestCase { } /** Here, a partially specified query profile is added to a non-mandatory field, making the request underspecified */ + @Test public void testNonMandatoryUnderspecifiedUserProfileSetInRequest() { - QueryProfileRegistry registry = new QueryProfileRegistry(); + Fixture1 fixture = new Fixture1(); + QueryProfile test = new QueryProfile("test"); - test.setType(type); - registry.register(test); + test.setType(fixture.type); + fixture.registry.register(test); - QueryProfile myUser=new QueryProfile("user"); - myUser.setType(user); - myUser.set("myUserInteger", 1, registry); - myUser.set("myUserString","userValue", registry); - test.set("myUserQueryProfile",myUser, registry); - registry.register(myUser); + QueryProfile myUser = new QueryProfile("user"); + myUser.setType(fixture.user); + myUser.set("myUserInteger", 1, fixture.registry); + myUser.set("myUserString", "userValue", fixture.registry); + test.set("myUserQueryProfile", myUser, fixture.registry); + fixture.registry.register(myUser); - QueryProfile otherUser=new QueryProfile("otherUser"); - otherUser.setType(user); - otherUser.set("myUserInteger", 2, registry); - registry.register(otherUser); + QueryProfile otherUser = new QueryProfile("otherUser"); + otherUser.setType(fixture.user); + otherUser.set("myUserInteger", 2, fixture.registry); + fixture.registry.register(otherUser); - CompiledQueryProfileRegistry cRegistry = registry.compile(); + CompiledQueryProfileRegistry cRegistry = fixture.registry.compile(); // Fully specified request assertError(null, new Query(HttpRequest.createTestRequest("?myString=aString", Method.GET), cRegistry.getComponent("test"))); @@ -194,6 +205,62 @@ public class MandatoryTestCase extends junit.framework.TestCase { assertError(null, new Query(HttpRequest.createTestRequest("?myString=aString&myQueryProfile=otherUser&myQueryProfile.myUserString=userString", Method.GET), cRegistry.getComponent("test"))); } + private static class Fixture2 { + + final QueryProfileRegistry registry = new QueryProfileRegistry(); + final QueryProfileTypeRegistry typeRegistry = new QueryProfileTypeRegistry(); + final QueryProfileType rootType = new QueryProfileType(new ComponentId("root")); + final QueryProfileType mandatoryType = new QueryProfileType(new ComponentId("mandatory-type")); + + public Fixture2() { + typeRegistry.register(rootType); + typeRegistry.register(mandatoryType); + + mandatoryType.inherited().add(rootType); + mandatoryType.addField(new FieldDescription("foobar", FieldType.fromString("string", typeRegistry), true)); + } + + } + + @Test + public void testMandatoryInParentType() { + Fixture2 fixture = new Fixture2(); + + QueryProfile defaultProfile = new QueryProfile("default"); + defaultProfile.setType(fixture.rootType); + + QueryProfile mandatoryProfile = new QueryProfile("mandatory"); + mandatoryProfile.setType(fixture.rootType); + mandatoryProfile.setType(fixture.mandatoryType); + + fixture.registry.register(defaultProfile); + fixture.registry.register(mandatoryProfile); + CompiledQueryProfileRegistry cRegistry = fixture.registry.compile(); + + assertError("Incomplete query: Parameter 'foobar' is mandatory in query profile 'mandatory' of type 'mandatory-type' but is not set", + new Query(QueryTestCase.httpEncode("?queryProfile=mandatory"), cRegistry.getComponent("mandatory"))); + } + + @Test + public void testMandatoryInParentTypeWithInheritance() { + Fixture2 fixture = new Fixture2(); + + QueryProfile defaultProfile = new QueryProfile("default"); + defaultProfile.setType(fixture.rootType); + + QueryProfile mandatoryProfile = new QueryProfile("mandatory"); + mandatoryProfile.setType(fixture.rootType); + mandatoryProfile.addInherited(defaultProfile); // The single difference from the test above + mandatoryProfile.setType(fixture.mandatoryType); + + fixture.registry.register(defaultProfile); + fixture.registry.register(mandatoryProfile); + CompiledQueryProfileRegistry cRegistry = fixture.registry.compile(); + + assertError("Incomplete query: Parameter 'foobar' is mandatory in query profile 'mandatory' of type 'mandatory-type' but is not set", + new Query(QueryTestCase.httpEncode("?queryProfile=mandatory"), cRegistry.getComponent("mandatory"))); + } + private void assertError(String message,Query query) { assertEquals(message, query.validate()); } diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/OwnershipIssues.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/OwnershipIssues.java index 8ded0c5fb52..91b5eb89c38 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/OwnershipIssues.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/OwnershipIssues.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.api.integration.organization; import com.yahoo.config.provision.ApplicationId; diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/security/KeyServiceMock.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/security/KeyServiceMock.java index 46fa2a593c5..d2a4b675f6d 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/security/KeyServiceMock.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/security/KeyServiceMock.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.api.integration.security; /** diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/stubs/DummyOwnershipIssues.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/stubs/DummyOwnershipIssues.java index 0cf103739d1..6e4761d1cf8 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/stubs/DummyOwnershipIssues.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/stubs/DummyOwnershipIssues.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.api.integration.stubs; import com.yahoo.config.provision.ApplicationId; diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/zone/ZoneFilter.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/zone/ZoneFilter.java index f718b86ca40..d8a3fa1ce96 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/zone/ZoneFilter.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/zone/ZoneFilter.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.api.integration.zone; /** diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/zone/ZoneFilterMock.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/zone/ZoneFilterMock.java index a7d51fa4d24..70e90554735 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/zone/ZoneFilterMock.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/zone/ZoneFilterMock.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.api.integration.zone; import com.yahoo.config.provision.Environment; diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/zone/ZoneList.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/zone/ZoneList.java index 408168e41da..27e8a598043 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/zone/ZoneList.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/zone/ZoneList.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.api.integration.zone; import com.yahoo.config.provision.Environment; diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java index b75f80917a9..d5ce613b98d 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java @@ -11,9 +11,8 @@ import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.api.integration.MetricsService.ApplicationMetrics; import com.yahoo.vespa.hosted.controller.api.integration.organization.IssueId; import com.yahoo.vespa.hosted.controller.application.ApplicationRotation; -import com.yahoo.vespa.hosted.controller.application.ApplicationRevision; +import com.yahoo.vespa.hosted.controller.application.ApplicationVersion; import com.yahoo.vespa.hosted.controller.application.Change; -import com.yahoo.vespa.hosted.controller.application.Change.VersionChange; import com.yahoo.vespa.hosted.controller.application.Deployment; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs; import com.yahoo.vespa.hosted.controller.rotation.RotationId; @@ -30,9 +29,9 @@ import java.util.stream.Collectors; /** * An instance of an application. - * + * * This is immutable. - * + * * @author bratseth */ public class Application { @@ -42,7 +41,7 @@ public class Application { private final ValidationOverrides validationOverrides; private final Map<ZoneId, Deployment> deployments; private final DeploymentJobs deploymentJobs; - private final Optional<Change> deploying; + private final Change change; private final boolean outstandingChange; private final Optional<IssueId> ownershipIssueId; private final ApplicationMetrics metrics; @@ -52,22 +51,22 @@ public class Application { public Application(ApplicationId id) { this(id, DeploymentSpec.empty, ValidationOverrides.empty, Collections.emptyMap(), new DeploymentJobs(Optional.empty(), Collections.emptyList(), Optional.empty()), - Optional.empty(), false, Optional.empty(), new ApplicationMetrics(0, 0), + Change.empty(), false, Optional.empty(), new ApplicationMetrics(0, 0), Optional.empty()); } /** Used from persistence layer: Do not use */ - public Application(ApplicationId id, DeploymentSpec deploymentSpec, ValidationOverrides validationOverrides, - List<Deployment> deployments, DeploymentJobs deploymentJobs, Optional<Change> deploying, + public Application(ApplicationId id, DeploymentSpec deploymentSpec, ValidationOverrides validationOverrides, + List<Deployment> deployments, DeploymentJobs deploymentJobs, Change change, boolean outstandingChange, Optional<IssueId> ownershipIssueId, ApplicationMetrics metrics, Optional<RotationId> rotation) { - this(id, deploymentSpec, validationOverrides, + this(id, deploymentSpec, validationOverrides, deployments.stream().collect(Collectors.toMap(Deployment::zone, d -> d)), - deploymentJobs, deploying, outstandingChange, ownershipIssueId, metrics, rotation); + deploymentJobs, change, outstandingChange, ownershipIssueId, metrics, rotation); } Application(ApplicationId id, DeploymentSpec deploymentSpec, ValidationOverrides validationOverrides, - Map<ZoneId, Deployment> deployments, DeploymentJobs deploymentJobs, Optional<Change> deploying, + Map<ZoneId, Deployment> deployments, DeploymentJobs deploymentJobs, Change change, boolean outstandingChange, Optional<IssueId> ownershipIssueId, ApplicationMetrics metrics, Optional<RotationId> rotation) { Objects.requireNonNull(id, "id cannot be null"); @@ -75,7 +74,7 @@ public class Application { Objects.requireNonNull(validationOverrides, "validationOverrides cannot be null"); Objects.requireNonNull(deployments, "deployments cannot be null"); Objects.requireNonNull(deploymentJobs, "deploymentJobs cannot be null"); - Objects.requireNonNull(deploying, "deploying cannot be null"); + Objects.requireNonNull(change, "change cannot be null"); Objects.requireNonNull(metrics, "metrics cannot be null"); Objects.requireNonNull(rotation, "rotation cannot be null"); this.id = id; @@ -83,7 +82,7 @@ public class Application { this.validationOverrides = validationOverrides; this.deployments = ImmutableMap.copyOf(deployments); this.deploymentJobs = deploymentJobs; - this.deploying = deploying; + this.change = change; this.outstandingChange = outstandingChange; this.ownershipIssueId = ownershipIssueId; this.metrics = metrics; @@ -91,24 +90,24 @@ public class Application { } public ApplicationId id() { return id; } - - /** - * Returns the last deployed deployment spec of this application, - * or the empty deployment spec if it has never been deployed + + /** + * Returns the last deployed deployment spec of this application, + * or the empty deployment spec if it has never been deployed */ public DeploymentSpec deploymentSpec() { return deploymentSpec; } /** - * Returns the last deployed validation overrides of this application, + * Returns the last deployed validation overrides of this application, * or the empty validation overrides if it has never been deployed * (or was deployed with an empty/missing validation overrides) */ public ValidationOverrides validationOverrides() { return validationOverrides; } - + /** Returns an immutable map of the current deployments of this */ public Map<ZoneId, Deployment> deployments() { return deployments; } - /** + /** * Returns an immutable map of the current *production* deployments of this * (deployments also includes manually deployed environments) */ @@ -121,10 +120,10 @@ public class Application { public DeploymentJobs deploymentJobs() { return deploymentJobs; } /** - * Returns the change that is currently in the process of being deployed on this application, - * or empty if no change is currently being deployed. + * Returns the change that should currently be deployed for this application, + * which is empty when no change is in progress. */ - public Optional<Change> deploying() { return deploying; } + public Change change() { return change; } /** * Returns whether this has an outstanding change (in the source repository), which @@ -152,10 +151,7 @@ public class Application { /** Returns the version a new deployment to this zone should use for this application */ public Version deployVersionIn(ZoneId zone, Controller controller) { - if (deploying().isPresent() && deploying().get() instanceof VersionChange) - return ((Change.VersionChange) deploying().get()).version(); - - return versionIn(zone, controller); + return change.platform().orElse(versionIn(zone, controller)); } /** Returns the current version this application has, or if none; should use, in the given zone */ @@ -164,17 +160,22 @@ public class Application { .orElse(oldestDeployedVersion().orElse(controller.systemVersion())); } - /** Returns the revision a new deployment to this zone should use for this application, or empty if we don't know */ - public Optional<ApplicationRevision> deployRevisionIn(ZoneId zone) { - if (deploying().isPresent() && deploying().get() instanceof Change.ApplicationChange) - return ((Change.ApplicationChange) deploying().get()).revision(); + /** Returns the application version a deployment to this zone should use, or empty if we don't know */ + public Optional<ApplicationVersion> deployApplicationVersionIn(ZoneId zone) { + if (change().application().isPresent()) { + ApplicationVersion version = change().application().get(); + if (version == ApplicationVersion.unknown) + return Optional.empty(); + else + return Optional.of(version); + } - return revisionIn(zone); + return applicationVersionIn(zone); } - /** Returns the revision this application is or should be deployed with in the given zone, or empty if unknown. */ - public Optional<ApplicationRevision> revisionIn(ZoneId zone) { - return Optional.ofNullable(deployments().get(zone)).map(Deployment::revision); + /** Returns the application version that is or should be deployed with in the given zone, or empty if unknown. */ + public Optional<ApplicationVersion> applicationVersionIn(ZoneId zone) { + return Optional.ofNullable(deployments().get(zone)).map(Deployment::applicationVersion); } /** Returns the global rotation of this, if present */ diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java index 2c2dcfe549b..83c3d6a5d11 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java @@ -7,7 +7,7 @@ import com.yahoo.config.application.api.ValidationId; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.Environment; import com.yahoo.config.provision.TenantName; -import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; +import com.yahoo.vespa.athenz.api.NToken; import com.yahoo.vespa.curator.Lock; import com.yahoo.vespa.hosted.controller.api.ActivateResult; import com.yahoo.vespa.hosted.controller.api.InstanceEndpoints; @@ -22,13 +22,13 @@ import com.yahoo.vespa.hosted.controller.api.identifiers.Hostname; import com.yahoo.vespa.hosted.controller.api.identifiers.RevisionId; import com.yahoo.vespa.hosted.controller.api.identifiers.TenantId; import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzClientFactory; -import com.yahoo.vespa.athenz.api.NToken; import com.yahoo.vespa.hosted.controller.api.integration.athenz.ZmsClient; import com.yahoo.vespa.hosted.controller.api.integration.athenz.ZmsException; import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServerClient; import com.yahoo.vespa.hosted.controller.api.integration.configserver.Log; import com.yahoo.vespa.hosted.controller.api.integration.configserver.NoInstanceException; import com.yahoo.vespa.hosted.controller.api.integration.configserver.PrepareResponse; +import com.yahoo.vespa.hosted.controller.api.integration.deployment.ArtifactRepository; import com.yahoo.vespa.hosted.controller.api.integration.dns.NameService; import com.yahoo.vespa.hosted.controller.api.integration.dns.Record; import com.yahoo.vespa.hosted.controller.api.integration.dns.RecordData; @@ -36,9 +36,9 @@ import com.yahoo.vespa.hosted.controller.api.integration.dns.RecordId; import com.yahoo.vespa.hosted.controller.api.integration.dns.RecordName; import com.yahoo.vespa.hosted.controller.api.integration.routing.RoutingEndpoint; import com.yahoo.vespa.hosted.controller.api.integration.routing.RoutingGenerator; +import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; -import com.yahoo.vespa.hosted.controller.application.ApplicationRevision; -import com.yahoo.vespa.hosted.controller.application.Change; +import com.yahoo.vespa.hosted.controller.application.ApplicationVersion; import com.yahoo.vespa.hosted.controller.application.Deployment; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobReport; @@ -90,6 +90,7 @@ public class ApplicationController { /** For working memory storage and sharing between controllers */ private final CuratorDb curator; + private final ArtifactRepository artifactRepository; private final RotationRepository rotationRepository; private final AthenzClientFactory zmsClientFactory; private final NameService nameService; @@ -102,6 +103,7 @@ public class ApplicationController { ApplicationController(Controller controller, ControllerDb db, CuratorDb curator, AthenzClientFactory zmsClientFactory, RotationsConfig rotationsConfig, NameService nameService, ConfigServerClient configserverClient, + ArtifactRepository artifactRepository, RoutingGenerator routingGenerator, Clock clock) { this.controller = controller; this.db = db; @@ -112,6 +114,7 @@ public class ApplicationController { this.routingGenerator = routingGenerator; this.clock = clock; + this.artifactRepository = artifactRepository; this.rotationRepository = new RotationRepository(rotationsConfig, this, curator); this.deploymentTrigger = new DeploymentTrigger(controller, curator, clock); @@ -271,7 +274,8 @@ public class ApplicationController { /** Deploys an application. If the application does not exist it is created. */ // TODO: Get rid of the options arg public ActivateResult deployApplication(ApplicationId applicationId, ZoneId zone, - ApplicationPackage applicationPackage, DeployOptions options) { + Optional<ApplicationPackage> applicationPackageFromDeployer, + DeployOptions options) { try (Lock lock = lock(applicationId)) { // TODO: Move application creation outside, to the deploy call in the handler. LockedApplication application = get(applicationId) @@ -281,40 +285,66 @@ public class ApplicationController { return new LockedApplication(new Application(applicationId), lock); }); - // Determine what we are doing + // Determine Vespa version to use Version version; - if (options.deployCurrentVersion) + if (options.deployCurrentVersion) { version = application.versionIn(zone, controller); - else if (canDeployDirectlyTo(zone, options)) + } else if (canDeployDirectlyTo(zone, options)) { version = options.vespaVersion.map(Version::new).orElse(controller.systemVersion()); - else if ( ! application.deploying().isPresent() && ! zone.environment().isManuallyDeployed()) - return unexpectedDeployment(applicationId, zone, applicationPackage); - else + } else if (! application.change().isPresent() && ! zone.environment().isManuallyDeployed()) { + return unexpectedDeployment(applicationId, zone, applicationPackageFromDeployer); + } else { version = application.deployVersionIn(zone, controller); + } Optional<DeploymentJobs.JobType> jobType = DeploymentJobs.JobType.from(controller.system(), zone); - ApplicationRevision revision = toApplicationPackageRevision(applicationPackage, options.screwdriverBuildJob); + if (!jobType.isPresent() && !applicationPackageFromDeployer.isPresent()) { + throw new IllegalArgumentException("Unable to determine job type from zone '" + zone + + "' and no application package was given"); + } - if ( ! options.deployCurrentVersion) { - // Add missing information to application (unless we're deploying the previous version (initial staging step) - application = application.with(applicationPackage.deploymentSpec()); - application = application.with(applicationPackage.validationOverrides()); - if (options.screwdriverBuildJob.isPresent() && options.screwdriverBuildJob.get().screwdriverId != null) - application = application.withProjectId(options.screwdriverBuildJob.get().screwdriverId.value()); - if (application.deploying().isPresent() && application.deploying().get() instanceof Change.ApplicationChange) - application = application.withDeploying(Optional.of(Change.ApplicationChange.of(revision))); - if ( ! canDeployDirectlyTo(zone, options) && jobType.isPresent()) { + // Determine which application package to use + ApplicationPackage applicationPackage; + ApplicationVersion applicationVersion; + if (applicationPackageFromDeployer.isPresent()) { + applicationVersion = toApplicationPackageRevision(applicationPackageFromDeployer.get(), + options.screwdriverBuildJob); + applicationPackage = applicationPackageFromDeployer.get(); + } else { + applicationVersion = application.deployApplicationVersion(jobType.get(), controller) + .orElseThrow(() -> new IllegalArgumentException("Cannot determine application version to use in " + zone)); + applicationPackage = new ApplicationPackage(artifactRepository.getApplicationPackage( + applicationId, applicationVersion.id()) + ); + } + + validate(applicationPackage.deploymentSpec()); + + // TODO: Remove after introducing new application version number + if ( ! options.deployCurrentVersion && applicationPackageFromDeployer.isPresent()) { + if (application.change().application().isPresent()) { + application = application.withDeploying(application.change().with(applicationVersion)); + } + if (!canDeployDirectlyTo(zone, options) && jobType.isPresent()) { // Update with (potentially) missing information about what we triggered: // * When someone else triggered the job, we need to store a stand-in triggering event. - // * When this is the system test job, we need to record the new revision, for future use. + // * When this is the system test job, we need to record the new application version, + // for future use. JobStatus.JobRun triggering = getOrCreateTriggering(application, version, jobType.get()); application = application.withJobTriggering(jobType.get(), - application.deploying(), + application.change(), triggering.at(), version, - Optional.of(revision), + applicationVersion, triggering.reason()); } + } + + // Update application with information from application package + if (!options.deployCurrentVersion) { + // Store information about application package + application = application.with(applicationPackage.deploymentSpec()); + application = application.with(applicationPackage.validationOverrides()); // Delete zones not listed in DeploymentSpec, if allowed // We do this at deployment time to be able to return a validation failure message when necessary @@ -326,15 +356,19 @@ public class ApplicationController { store(application); // store missing information even if we fail deployment below } - if ( ! canDeployDirectlyTo(zone, options)) { // validate automated deployment - if ( ! application.deploymentJobs().isDeployableTo(zone.environment(), application.deploying())) + // Validate automated deployment + if (!canDeployDirectlyTo(zone, options)) { + if (!application.deploymentJobs().isDeployableTo(zone.environment(), application.change())) { throw new IllegalArgumentException("Rejecting deployment of " + application + " to " + zone + - " as " + application.deploying().get() + " is not tested"); + " as " + application.change() + " is not tested"); + } Deployment existingDeployment = application.deployments().get(zone); - if (zone.environment().isProduction() && existingDeployment != null && existingDeployment.version().isAfter(version)) + if (zone.environment().isProduction() && existingDeployment != null && + existingDeployment.version().isAfter(version)) { throw new IllegalArgumentException("Rejecting deployment of " + application + " to " + zone + " as the requested version " + version + " is older than" + " the current version " + existingDeployment.version()); + } } application = withRotation(application, zone); @@ -353,11 +387,12 @@ public class ApplicationController { configserverClient.prepare(new DeploymentId(applicationId, zone), options, cnames, rotationNames, applicationPackage.zippedContent()); preparedApplication.activate(); - application = application.withNewDeployment(zone, revision, version, clock.instant()); + application = application.withNewDeployment(zone, applicationVersion, version, clock.instant()); store(application); - return new ActivateResult(new RevisionId(applicationPackage.hash()), preparedApplication.prepareResponse()); + return new ActivateResult(new RevisionId(applicationPackage.hash()), preparedApplication.prepareResponse(), + applicationPackage.zippedContent().length); } } @@ -376,7 +411,8 @@ public class ApplicationController { return application; } - private ActivateResult unexpectedDeployment(ApplicationId applicationId, ZoneId zone, ApplicationPackage applicationPackage) { + private ActivateResult unexpectedDeployment(ApplicationId applicationId, ZoneId zone, + Optional<ApplicationPackage> applicationPackage) { Log logEntry = new Log(); logEntry.level = "WARNING"; logEntry.time = clock.instant().toEpochMilli(); @@ -384,7 +420,9 @@ public class ApplicationController { PrepareResponse prepareResponse = new PrepareResponse(); prepareResponse.log = Collections.singletonList(logEntry); prepareResponse.configChangeActions = new ConfigChangeActions(Collections.emptyList(), Collections.emptyList()); - return new ActivateResult(new RevisionId(applicationPackage.hash()), prepareResponse); + return new ActivateResult(new RevisionId(applicationPackage.map(ApplicationPackage::hash) + .orElse("0")), prepareResponse, + applicationPackage.map(a -> a.zippedContent().length).orElse(0)); } private LockedApplication deleteRemovedDeployments(LockedApplication application) { @@ -435,7 +473,7 @@ public class ApplicationController { } private JobStatus.JobRun incompleteTriggeringEvent(Version version) { - return new JobStatus.JobRun(-1, version, Optional.empty(), false, "", clock.instant()); + return new JobStatus.JobRun(-1, version, ApplicationVersion.unknown, false, "", clock.instant()); } private DeployOptions withVersion(Version version, DeployOptions options) { @@ -445,18 +483,18 @@ public class ApplicationController { options.deployCurrentVersion); } - private ApplicationRevision toApplicationPackageRevision(ApplicationPackage applicationPackage, - Optional<ScrewdriverBuildJob> screwDriverBuildJob) { - if ( ! screwDriverBuildJob.isPresent()) - return ApplicationRevision.from(applicationPackage.hash()); + private ApplicationVersion toApplicationPackageRevision(ApplicationPackage applicationPackage, + Optional<ScrewdriverBuildJob> buildJob) { + if ( ! buildJob.isPresent()) + return ApplicationVersion.from(applicationPackage.hash()); - GitRevision gitRevision = screwDriverBuildJob.get().gitRevision; + GitRevision gitRevision = buildJob.get().gitRevision; if (gitRevision.repository == null || gitRevision.branch == null || gitRevision.commit == null) - return ApplicationRevision.from(applicationPackage.hash()); + return ApplicationVersion.from(applicationPackage.hash()); - return ApplicationRevision.from(applicationPackage.hash(), new SourceRevision(gitRevision.repository.id(), - gitRevision.branch.id(), - gitRevision.commit.id())); + return ApplicationVersion.from(applicationPackage.hash(), new SourceRevision(gitRevision.repository.id(), + gitRevision.branch.id(), + gitRevision.commit.id())); } /** Register a DNS name for rotation */ @@ -661,7 +699,7 @@ public class ApplicationController { } /** Verify that each of the production zones listed in the deployment spec exist in this system. */ - public void validate(DeploymentSpec deploymentSpec) { + private void validate(DeploymentSpec deploymentSpec) { deploymentSpec.zones().stream() .filter(zone -> zone.environment() == Environment.prod) .forEach(zone -> { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java index 0e13f4181c4..0ec00f61311 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java @@ -8,8 +8,6 @@ import com.yahoo.component.Version; import com.yahoo.component.Vtag; import com.yahoo.config.provision.SystemName; import com.yahoo.vespa.athenz.api.AthenzDomain; -import com.yahoo.vespa.hosted.controller.api.integration.noderepository.NodeRepositoryClientInterface; -import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.api.identifiers.DeploymentId; import com.yahoo.vespa.hosted.controller.api.identifiers.Property; import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; @@ -17,13 +15,16 @@ import com.yahoo.vespa.hosted.controller.api.integration.MetricsService; import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzClientFactory; import com.yahoo.vespa.hosted.controller.api.integration.chef.Chef; import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServerClient; +import com.yahoo.vespa.hosted.controller.api.integration.deployment.ArtifactRepository; import com.yahoo.vespa.hosted.controller.api.integration.dns.NameService; import com.yahoo.vespa.hosted.controller.api.integration.entity.EntityService; import com.yahoo.vespa.hosted.controller.api.integration.github.GitHub; +import com.yahoo.vespa.hosted.controller.api.integration.noderepository.NodeRepositoryClientInterface; import com.yahoo.vespa.hosted.controller.api.integration.organization.Organization; import com.yahoo.vespa.hosted.controller.api.integration.routing.GlobalRoutingService; import com.yahoo.vespa.hosted.controller.api.integration.routing.RotationStatus; import com.yahoo.vespa.hosted.controller.api.integration.routing.RoutingGenerator; +import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneRegistry; import com.yahoo.vespa.hosted.controller.persistence.ControllerDb; import com.yahoo.vespa.hosted.controller.persistence.CuratorDb; @@ -85,11 +86,12 @@ public class Controller extends AbstractComponent { GlobalRoutingService globalRoutingService, ZoneRegistry zoneRegistry, ConfigServerClient configServerClient, NodeRepositoryClientInterface nodeRepositoryClient, MetricsService metricsService, NameService nameService, - RoutingGenerator routingGenerator, Chef chefClient, AthenzClientFactory athenzClientFactory) { + RoutingGenerator routingGenerator, Chef chefClient, AthenzClientFactory athenzClientFactory, + ArtifactRepository artifactRepository) { this(db, curator, rotationsConfig, gitHub, entityService, organization, globalRoutingService, zoneRegistry, configServerClient, nodeRepositoryClient, metricsService, nameService, routingGenerator, chefClient, - Clock.systemUTC(), athenzClientFactory); + Clock.systemUTC(), athenzClientFactory, artifactRepository); } public Controller(ControllerDb db, CuratorDb curator, RotationsConfig rotationsConfig, @@ -98,7 +100,7 @@ public class Controller extends AbstractComponent { ZoneRegistry zoneRegistry, ConfigServerClient configServerClient, NodeRepositoryClientInterface nodeRepositoryClient, MetricsService metricsService, NameService nameService, RoutingGenerator routingGenerator, Chef chefClient, Clock clock, - AthenzClientFactory athenzClientFactory) { + AthenzClientFactory athenzClientFactory, ArtifactRepository artifactRepository) { Objects.requireNonNull(db, "Controller db cannot be null"); Objects.requireNonNull(curator, "Curator cannot be null"); Objects.requireNonNull(rotationsConfig, "RotationsConfig cannot be null"); @@ -115,6 +117,7 @@ public class Controller extends AbstractComponent { Objects.requireNonNull(chefClient, "ChefClient cannot be null"); Objects.requireNonNull(clock, "Clock cannot be null"); Objects.requireNonNull(athenzClientFactory, "Athens cannot be null"); + Objects.requireNonNull(artifactRepository, "ArtifactRepository cannot be null"); this.curator = curator; this.gitHub = gitHub; @@ -131,7 +134,8 @@ public class Controller extends AbstractComponent { applicationController = new ApplicationController(this, db, curator, athenzClientFactory, rotationsConfig, - nameService, configServerClient, routingGenerator, clock); + nameService, configServerClient, artifactRepository, + routingGenerator, clock); tenantController = new TenantController(this, db, curator, entityService, athenzClientFactory); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java index 5fa5b8c318b..e744df0da68 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller; import com.yahoo.component.Version; @@ -11,7 +12,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.MetricsService; import com.yahoo.vespa.hosted.controller.api.integration.MetricsService.ApplicationMetrics; import com.yahoo.vespa.hosted.controller.api.integration.organization.IssueId; import com.yahoo.vespa.hosted.controller.application.ApplicationRotation; -import com.yahoo.vespa.hosted.controller.application.ApplicationRevision; +import com.yahoo.vespa.hosted.controller.application.ApplicationVersion; import com.yahoo.vespa.hosted.controller.application.Change; import com.yahoo.vespa.hosted.controller.application.ClusterInfo; import com.yahoo.vespa.hosted.controller.application.ClusterUtilization; @@ -41,7 +42,7 @@ public class LockedApplication extends Application { * @param application The application to lock. * @param lock The lock for the application. */ - LockedApplication(Application application, Lock lock) { + LockedApplication(Application application, @SuppressWarnings("unused") Lock lock) { this(new Builder(application)); } @@ -63,15 +64,18 @@ public class LockedApplication extends Application { return new LockedApplication(new Builder(this).with(deploymentJobs().withCompletion(report, notificationTime, controller))); } - public LockedApplication withJobTriggering(JobType type, Optional<Change> change, Instant triggerTime, - Version version, Optional<ApplicationRevision> revision, String reason) { - return new LockedApplication(new Builder(this).with(deploymentJobs().withTriggering(type, change, version, revision, reason, triggerTime))); + public LockedApplication withJobTriggering(JobType type, Change change, Instant triggerTime, + Version version, ApplicationVersion applicationVersion, + String reason) { + return new LockedApplication(new Builder(this).with(deploymentJobs().withTriggering(type, change, version, applicationVersion, reason, triggerTime))); } - public LockedApplication withNewDeployment(ZoneId zone, ApplicationRevision revision, Version version, Instant instant) { + public LockedApplication withNewDeployment(ZoneId zone, ApplicationVersion applicationVersion, Version version, + Instant instant) { // Use info from previous deployment if available, otherwise create a new one. - Deployment previousDeployment = deployments().getOrDefault(zone, new Deployment(zone, revision, version, instant)); - Deployment newDeployment = new Deployment(zone, revision, version, instant, + Deployment previousDeployment = deployments().getOrDefault(zone, new Deployment(zone, applicationVersion, + version, instant)); + Deployment newDeployment = new Deployment(zone, applicationVersion, version, instant, previousDeployment.clusterUtils(), previousDeployment.clusterInfo(), previousDeployment.metrics()); @@ -115,7 +119,7 @@ public class LockedApplication extends Application { return new LockedApplication(new Builder(this).with(validationOverrides)); } - public LockedApplication withDeploying(Optional<Change> deploying) { + public LockedApplication withDeploying(Change deploying) { return new LockedApplication(new Builder(this).withDeploying(deploying)); } @@ -141,10 +145,10 @@ public class LockedApplication extends Application { : deployVersionIn(jobType.zone(controller.system()).get(), controller); } - public Optional<ApplicationRevision> deployRevisionFor(DeploymentJobs.JobType jobType, Controller controller) { + public Optional<ApplicationVersion> deployApplicationVersion(DeploymentJobs.JobType jobType, Controller controller) { return jobType == JobType.component - ? Optional.empty() - : deployRevisionIn(jobType.zone(controller.system()).get()); + ? Optional.empty() + : deployApplicationVersionIn(jobType.zone(controller.system()).get()); } /** Don't expose non-leaf sub-objects. */ @@ -162,7 +166,7 @@ public class LockedApplication extends Application { private ValidationOverrides validationOverrides; private Map<ZoneId, Deployment> deployments; private DeploymentJobs deploymentJobs; - private Optional<Change> deploying; + private Change deploying; private boolean hasOutstandingChange; private Optional<IssueId> ownershipIssueId; private ApplicationMetrics metrics; @@ -174,7 +178,7 @@ public class LockedApplication extends Application { this.validationOverrides = application.validationOverrides(); this.deployments = application.deployments(); this.deploymentJobs = application.deploymentJobs(); - this.deploying = application.deploying(); + this.deploying = application.change(); this.hasOutstandingChange = application.hasOutstandingChange(); this.ownershipIssueId = application.ownershipIssueId(); this.metrics = application.metrics(); @@ -201,7 +205,7 @@ public class LockedApplication extends Application { return this; } - private Builder withDeploying(Optional<Change> deploying) { + private Builder withDeploying(Change deploying) { this.deploying = deploying; return this; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/api/ActivateResult.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/api/ActivateResult.java index 6b1c5d56a5f..271942ff9a3 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/api/ActivateResult.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/api/ActivateResult.java @@ -2,11 +2,8 @@ package com.yahoo.vespa.hosted.controller.api; import com.yahoo.vespa.hosted.controller.api.identifiers.RevisionId; -import com.yahoo.vespa.hosted.controller.api.integration.configserver.Log; import com.yahoo.vespa.hosted.controller.api.integration.configserver.PrepareResponse; -import java.util.List; - /** * @author Oyvind Gronnesby */ @@ -14,17 +11,23 @@ public class ActivateResult { private final RevisionId revisionId; private final PrepareResponse prepareResponse; + private final long applicationZipSizeBytes; - public ActivateResult(RevisionId revisionId, PrepareResponse prepareResponse) { + public ActivateResult(RevisionId revisionId, PrepareResponse prepareResponse, long applicationZipSizeBytes) { this.revisionId = revisionId; this.prepareResponse = prepareResponse; + this.applicationZipSizeBytes = applicationZipSizeBytes; + } + + public long applicationZipSizeBytes() { + return applicationZipSizeBytes; } - public RevisionId getRevisionId() { + public RevisionId revisionId() { return revisionId; } - public PrepareResponse getPrepareResponse() { + public PrepareResponse prepareResponse() { return prepareResponse; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationList.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationList.java index 283d6a75178..777cb4011b6 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationList.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationList.java @@ -17,7 +17,7 @@ import java.util.stream.Stream; /** * A list of applications which can be filtered in various ways. - * + * * @author bratseth */ public class ApplicationList { @@ -27,9 +27,9 @@ public class ApplicationList { private ApplicationList(Iterable<Application> applications) { this.list = ImmutableList.copyOf(applications); } - + // ----------------------------------- Factories - + public static ApplicationList from(Iterable<Application> applications) { return new ApplicationList(applications); } @@ -54,7 +54,7 @@ public class ApplicationList { /** Returns the subset of applications which are currently upgrading (to any version) */ public ApplicationList upgrading() { - return listOf(list.stream().filter(application -> isUpgrading(application))); + return listOf(list.stream().filter(application -> application.change().platform().isPresent())); } /** Returns the subset of applications which are currently upgrading to the given version */ @@ -62,17 +62,12 @@ public class ApplicationList { return listOf(list.stream().filter(application -> isUpgradingTo(version, application))); } - /** Returns the subset of applications which are currently upgrading to a version lower than the given version */ - public ApplicationList upgradingToLowerThan(Version version) { - return listOf(list.stream().filter(application -> isUpgradingToLowerThan(version, application))); - } - /** Returns the subset of applications which are currently not upgrading to the given version */ public ApplicationList notUpgradingTo(Version version) { return listOf(list.stream().filter(application -> ! isUpgradingTo(version, application))); } - /** + /** * Returns the subset of applications which are currently not upgrading to the given version, * or returns all if no version is specified */ @@ -81,14 +76,9 @@ public class ApplicationList { return notUpgradingTo(version.get()); } - /** Returns the subset of applications which is currently not deploying a new application revision */ - public ApplicationList notDeployingApplication() { - return listOf(list.stream().filter(application -> ! isDeployingApplicationChange(application))); - } - /** Returns the subset of applications which is currently not deploying a change */ public ApplicationList notDeploying() { - return listOf(list.stream().filter(application -> ! application.deploying().isPresent())); + return listOf(list.stream().filter(application -> ! application.change().isPresent())); } /** Returns the subset of applications which currently does not have any failing jobs */ @@ -135,7 +125,7 @@ public class ApplicationList { public ApplicationList without(UpgradePolicy policy) { return listOf(list.stream().filter(a -> a.deploymentSpec().upgradePolicy() != policy)); } - + /** Returns the subset of applications which have at least one deployment on a lower version than the given one */ public ApplicationList onLowerVersionThan(Version version) { return listOf(list.stream() @@ -144,7 +134,7 @@ public class ApplicationList { } /** - * Returns the subset of applications which are not pull requests: + * Returns the subset of applications which are not pull requests: * Pull requests changes the application instance name to (default-pr)?[pull-request-number] */ public ApplicationList notPullRequest() { @@ -178,34 +168,10 @@ public class ApplicationList { return listOf(list.stream().sorted(Comparator.comparing(application -> application.oldestDeployedVersion().orElse(Version.emptyVersion)))); } - /** Returns the subset of applications that are not currently upgrading */ - public ApplicationList notCurrentlyUpgrading(Change.VersionChange change, Instant jobTimeoutLimit) { - return listOf(list.stream().filter(a -> ! currentlyUpgrading(change, a, jobTimeoutLimit))); - } - // ----------------------------------- Internal helpers - private static boolean isUpgrading(Application application) { - if ( ! (application.deploying().isPresent()) ) return false; - if ( ! (application.deploying().get() instanceof Change.VersionChange) ) return false; - return true; - } - private static boolean isUpgradingTo(Version version, Application application) { - if ( ! (application.deploying().isPresent()) ) return false; - if ( ! (application.deploying().get() instanceof Change.VersionChange) ) return false; - return ((Change.VersionChange)application.deploying().get()).version().equals(version); - } - - private static boolean isUpgradingToLowerThan(Version version, Application application) { - if ( ! application.deploying().isPresent()) return false; - if ( ! (application.deploying().get() instanceof Change.VersionChange) ) return false; - return ((Change.VersionChange)application.deploying().get()).version().isBefore(version); - } - - private static boolean isDeployingApplicationChange(Application application) { - if ( ! application.deploying().isPresent()) return false; - return application.deploying().get() instanceof Change.ApplicationChange; + return application.change().platform().equals(Optional.of(version)); } private static boolean failingOn(Version version, Application application) { @@ -215,13 +181,6 @@ public class ApplicationList { .isEmpty(); } - private static boolean currentlyUpgrading(Change.VersionChange change, Application application, Instant jobTimeoutLimit) { - return ! JobList.from(application) - .running(jobTimeoutLimit) - .lastTriggered().on(change.version()) - .isEmpty(); - } - private static boolean failingUpgradeToVersionSince(Application application, Version version, Instant threshold) { return ! JobList.from(application) .not().failingApplicationChange() diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationRevision.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationRevision.java deleted file mode 100644 index 1b875f28715..00000000000 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationRevision.java +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.controller.application; - -import java.util.Objects; -import java.util.Optional; - -/** - * An identifier of a particular revision (exact content) of an application package, - * optionally with information about the source of the package revision. - * - * @author bratseth - */ -public class ApplicationRevision { - - private final String applicationPackageHash; - - private final Optional<SourceRevision> source; - - private ApplicationRevision(String applicationPackageHash, Optional<SourceRevision> source) { - Objects.requireNonNull(applicationPackageHash, "applicationPackageHash cannot be null"); - this.applicationPackageHash = applicationPackageHash; - this.source = source; - } - - /** Create an application package revision where there is no information about its source */ - public static ApplicationRevision from(String applicationPackageHash) { - return new ApplicationRevision(applicationPackageHash, Optional.empty()); - } - - /** Create an application package revision with a source */ - public static ApplicationRevision from(String applicationPackageHash, SourceRevision source) { - return new ApplicationRevision(applicationPackageHash, Optional.of(source)); - } - - /** Returns a unique, content-based identifier of an application package (a hash of the content) */ - public String id() { return applicationPackageHash; } - - /** - * Returns information about the source of this revision, or empty if the source is not know/defined - * (which is the case for command-line deployment from developers, but never for deployment jobs) - */ - public Optional<SourceRevision> source() { return source; } - - @Override - public int hashCode() { return applicationPackageHash.hashCode(); } - - @Override - public boolean equals(Object other) { - if (this == other) return true; - if ( ! (other instanceof ApplicationRevision)) return false; - return this.applicationPackageHash.equals(((ApplicationRevision)other).applicationPackageHash); - } - - @Override - public String toString() { - return "Application package revision '" + applicationPackageHash + "'" + - (source.isPresent() ? " with " + source.get() : ""); - } - -} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationVersion.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationVersion.java new file mode 100644 index 00000000000..304d82b2bec --- /dev/null +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationVersion.java @@ -0,0 +1,107 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.application; + +import java.util.Objects; +import java.util.Optional; + +/** + * An application package version. This represents an build artifact, identified by a source revision and a build + * number. + * + * @author bratseth + * @author mpolden + */ +public class ApplicationVersion { + + // TODO: Remove the need for this + public static final ApplicationVersion unknown = new ApplicationVersion(); + + // Never changes. Only used to create a valid version number for the bundle + private static final String majorVersion = "1.0"; + + // TODO: Remove after introducing new application version + private final Optional<String> applicationPackageHash; + + // TODO: Make mandatory + private final Optional<SourceRevision> source; + private final Optional<Long> buildNumber; + + private ApplicationVersion() { + this.applicationPackageHash = Optional.empty(); + this.source = Optional.empty(); + this.buildNumber = Optional.empty(); + } + + private ApplicationVersion(Optional<String> applicationPackageHash, Optional<SourceRevision> source, + Optional<Long> buildNumber) { + Objects.requireNonNull(applicationPackageHash, "applicationPackageHash cannot be null"); + Objects.requireNonNull(source, "source cannot be null"); + Objects.requireNonNull(buildNumber, "buildNumber cannot be null"); + if (buildNumber.isPresent() && !source.isPresent()) { + throw new IllegalArgumentException("both buildNumber and source must be set if buildNumber is set"); + } + if ( ! buildNumber.isPresent() && ! applicationPackageHash.isPresent()) { + throw new IllegalArgumentException("applicationPackageHash must be given if buildNumber is unset"); + } + this.applicationPackageHash = applicationPackageHash; + this.source = source; + this.buildNumber = buildNumber; + } + + /** Create an application package revision where there is no information about its source */ + public static ApplicationVersion from(String applicationPackageHash) { + return new ApplicationVersion(Optional.of(applicationPackageHash), Optional.empty(), Optional.empty()); + } + + /** Create an application package revision with a source */ + public static ApplicationVersion from(String applicationPackageHash, SourceRevision source) { + return new ApplicationVersion(Optional.of(applicationPackageHash), Optional.of(source), Optional.empty()); + } + + /** Create an application package version from a completed build */ + public static ApplicationVersion from(SourceRevision source, long buildNumber) { + return new ApplicationVersion(Optional.empty(), Optional.of(source), Optional.of(buildNumber)); + } + + /** Returns an unique identifier for this version */ + public String id() { + if (applicationPackageHash.isPresent()) { + return applicationPackageHash.get(); + } + return String.format("%s.%d-%s", majorVersion, buildNumber.get(), abbreviateCommit(source.get().commit())); + } + + /** + * Returns information about the source of this revision, or empty if the source is not know/defined + * (which is the case for command-line deployment from developers, but never for deployment jobs) + */ + public Optional<SourceRevision> source() { return source; } + + /** Returns the build number that built this version */ + public Optional<Long> buildNumber() { return buildNumber; } + + @Override + public int hashCode() { return applicationPackageHash.hashCode(); } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if ( ! (other instanceof ApplicationVersion)) return false; + return this.applicationPackageHash.equals(((ApplicationVersion)other).applicationPackageHash); + } + + @Override + public String toString() { + if (buildNumber.isPresent()) { + return "Application package version: " + abbreviateCommit(source.get().commit()) + "-" + buildNumber.get(); + } + return "Application package revision '" + applicationPackageHash + "'" + + (source.isPresent() ? " with " + source.get() : ""); + } + + /** Abbreviate given commit hash to 9 characters */ + private static String abbreviateCommit(String hash) { + return hash.length() <= 9 ? hash : hash.substring(0, 9); + } + +} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Change.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Change.java index d9c22018d26..13d66c8d083 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Change.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Change.java @@ -9,97 +9,90 @@ import java.util.Objects; import java.util.Optional; /** - * A change to an application - * + * The changes to an application we currently wish to complete deploying. + * A goal of the system is to deploy platform and application versions separately. + * However, this goal must some times be traded against others, so a change can + * consist of both an application and platform version change. + * + * This is immutable. + * * @author bratseth */ -public abstract class Change { +public final class Change { + + private static final Change empty = new Change(Optional.empty(), Optional.empty()); + + /** The platform version we are upgrading to, or empty if none */ + private final Optional<Version> platform; + + /** The application version we are changing to, or empty if none */ + private final Optional<ApplicationVersion> application; + + private Change(Optional<Version> platform, Optional<ApplicationVersion> application) { + Objects.requireNonNull(platform, "platform cannot be null"); + Objects.requireNonNull(application, "application cannot be null"); + this.platform = platform; + this.application = application; + } /** Returns true if this change is blocked by the given spec at the given instant */ - public abstract boolean blockedBy(DeploymentSpec deploymentSpec, Instant instant); - - /** A change to the application package revision of an application */ - public static class ApplicationChange extends Change { - - private final Optional<ApplicationRevision> revision; - - private ApplicationChange(Optional<ApplicationRevision> revision) { - Objects.requireNonNull(revision, "revision cannot be null"); - this.revision = revision; - } - - /** The revision this changes to, or empty if not known yet */ - public Optional<ApplicationRevision> revision() { return revision; } - - @Override - public boolean blockedBy(DeploymentSpec deploymentSpec, Instant instant) { - return ! deploymentSpec.canChangeRevisionAt(instant); - } - - @Override - public int hashCode() { return revision.hashCode(); } - - @Override - public boolean equals(Object other) { - if (this == other) return true; - if ( ! (other instanceof ApplicationChange)) return false; - return ((ApplicationChange)other).revision.equals(this.revision); - } - - /** - * Creates an application change which we don't know anything about. - * We are notified that a change has occurred by completion of the component job - * but do not get to know about what the change is until a subsequent deployment - * happens. - */ - public static ApplicationChange unknown() { - return new ApplicationChange(Optional.empty()); - } - - public static ApplicationChange of(ApplicationRevision revision) { - return new ApplicationChange(Optional.of(revision)); - } - - @Override - public String toString() { - return "application change to " + revision.map(ApplicationRevision::toString).orElse("an unknown revision"); - } - + public boolean blockedBy(DeploymentSpec deploymentSpec, Instant instant) { + if (platform.isPresent() && ! deploymentSpec.canUpgradeAt(instant)) return true; + if (application.isPresent() && ! deploymentSpec.canChangeRevisionAt(instant)) return true; + return false; } - /** A change to the Vespa version running an application */ - public static class VersionChange extends Change { + /** Returns whether a change shoudl currently be deployed */ + public boolean isPresent() { + return platform.isPresent() || application.isPresent(); + } - private final Version version; + /** Returns the platform version change which should currently be deployed, if any */ + public Optional<Version> platform() { return platform; } - public VersionChange(Version version) { - Objects.requireNonNull(version, "version cannot be null"); - this.version = version; - } + /** Returns the application version change which should currently be deployed, if any */ + public Optional<ApplicationVersion> application() { return application; } - /** The Vespa version this changes to */ - public Version version() { return version; } + /** Returns an instance representing no change */ + public static Change empty() { return empty; } - @Override - public boolean blockedBy(DeploymentSpec deploymentSpec, Instant instant) { - return ! deploymentSpec.canUpgradeAt(instant); - } + /** Returns a version of this change which replaces or adds this application change */ + public Change with(ApplicationVersion applicationVersion) { + return new Change(platform, Optional.of(applicationVersion)); + } - @Override - public int hashCode() { return version.hashCode(); } + @Override + public int hashCode() { return Objects.hash(platform, application); } + + @Override + public boolean equals(Object other) { + if (other == this) return true; + if ( ! (other instanceof Change)) return false; + Change o = (Change)other; + if ( ! o.platform.equals(this.platform)) return false; + if ( ! o.application.equals(this.application)) return false; + return true; + } - @Override - public boolean equals(Object other) { - if (this == other) return true; - if ( ! (other instanceof VersionChange)) return false; - return ((VersionChange)other).version.equals(this.version); - } + @Override + public String toString() { + String platformString = platform.map(v -> "upgrade to " + v).orElse(null); + String applicationString = application.map(v -> "application change to " + v).orElse(null); + if (platformString != null && applicationString != null) + return platformString + " and " + applicationString; + if (platformString != null) + return platformString; + if (applicationString != null) + return applicationString; + return "no change"; + } - @Override - public String toString() { - return "version change to " + version; - } + public static Change of(ApplicationVersion applicationVersion) { + return new Change(Optional.empty(), Optional.of(applicationVersion)); + } + public static Change of(Version platformChange) { + return new Change(Optional.of(platformChange), Optional.empty()); } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java index 2364e87b345..8fa0c6da49c 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java @@ -19,28 +19,28 @@ import java.util.Objects; public class Deployment { private final ZoneId zone; - private final ApplicationRevision revision; + private final ApplicationVersion applicationVersion; private final Version version; private final Instant deployTime; private final Map<Id, ClusterUtilization> clusterUtils; private final Map<Id, ClusterInfo> clusterInfo; private final DeploymentMetrics metrics; - public Deployment(ZoneId zone, ApplicationRevision revision, Version version, Instant deployTime) { - this(zone, revision, version, deployTime, new HashMap<>(), new HashMap<>(), new DeploymentMetrics()); + public Deployment(ZoneId zone, ApplicationVersion applicationVersion, Version version, Instant deployTime) { + this(zone, applicationVersion, version, deployTime, new HashMap<>(), new HashMap<>(), new DeploymentMetrics()); } - public Deployment(ZoneId zone, ApplicationRevision revision, Version version, Instant deployTime, + public Deployment(ZoneId zone, ApplicationVersion applicationVersion, Version version, Instant deployTime, Map<Id, ClusterUtilization> clusterUtils, Map<Id, ClusterInfo> clusterInfo, DeploymentMetrics metrics) { Objects.requireNonNull(zone, "zone cannot be null"); - Objects.requireNonNull(revision, "revision cannot be null"); + Objects.requireNonNull(applicationVersion, "applicationVersion cannot be null"); Objects.requireNonNull(version, "version cannot be null"); Objects.requireNonNull(deployTime, "deployTime cannot be null"); Objects.requireNonNull(clusterUtils, "clusterUtils cannot be null"); Objects.requireNonNull(clusterInfo, "clusterInfo cannot be null"); Objects.requireNonNull(metrics, "deployment metrics cannot be null"); this.zone = zone; - this.revision = revision; + this.applicationVersion = applicationVersion; this.version = version; this.deployTime = deployTime; this.clusterUtils = clusterUtils; @@ -51,10 +51,10 @@ public class Deployment { /** Returns the zone this was deployed to */ public ZoneId zone() { return zone; } - /** Returns the revision of the application which was deployed */ - public ApplicationRevision revision() { return revision; } + /** Returns the deployed application version */ + public ApplicationVersion applicationVersion() { return applicationVersion; } - /** Returns the Vespa version which was deployed */ + /** Returns the deployed Vespa version */ public Version version() { return version; } /** Returns the time this was deployed */ @@ -69,15 +69,15 @@ public class Deployment { } public Deployment withClusterUtils(Map<Id, ClusterUtilization> clusterUtilization) { - return new Deployment(zone, revision, version, deployTime, clusterUtilization, clusterInfo, metrics); + return new Deployment(zone, applicationVersion, version, deployTime, clusterUtilization, clusterInfo, metrics); } public Deployment withClusterInfo(Map<Id, ClusterInfo> newClusterInfo) { - return new Deployment(zone, revision, version, deployTime, clusterUtils, newClusterInfo, metrics); + return new Deployment(zone, applicationVersion, version, deployTime, clusterUtils, newClusterInfo, metrics); } public Deployment withMetrics(DeploymentMetrics metrics) { - return new Deployment(zone, revision, version, deployTime, clusterUtils, clusterInfo, metrics); + return new Deployment(zone, applicationVersion, version, deployTime, clusterUtils, clusterInfo, metrics); } /** @return Key metrics for the deployment (application level) like QPS and document count */ @@ -107,6 +107,6 @@ public class Deployment { @Override public String toString() { - return "deployment to " + zone + " of " + revision + " on version " + version + " at " + deployTime; + return "deployment to " + zone + " of " + applicationVersion + " on version " + version + " at " + deployTime; } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentJobs.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentJobs.java index a728786e2ce..bb7b39eed0f 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentJobs.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentJobs.java @@ -65,17 +65,17 @@ public class DeploymentJobs { } public DeploymentJobs withTriggering(JobType jobType, - Optional<Change> change, + Change change, Version version, - Optional<ApplicationRevision> revision, + ApplicationVersion applicationVersion, String reason, Instant triggerTime) { Map<JobType, JobStatus> status = new LinkedHashMap<>(this.status); status.compute(jobType, (type, job) -> { if (job == null) job = JobStatus.initial(jobType); - return job.withTriggering( version, - revision, - change.isPresent() && change.get() instanceof Change.VersionChange, + return job.withTriggering(version, + applicationVersion, + change.platform().isPresent(), reason, triggerTime); }); @@ -117,14 +117,14 @@ public class DeploymentJobs { } /** Returns whether change can be deployed to the given environment */ - public boolean isDeployableTo(Environment environment, Optional<Change> change) { + public boolean isDeployableTo(Environment environment, Change change) { if (environment == null || ! change.isPresent()) { return true; } if (environment == Environment.staging) { - return isSuccessful(change.get(), JobType.systemTest); + return isSuccessful(change, JobType.systemTest); } else if (environment == Environment.prod) { - return isSuccessful(change.get(), JobType.stagingTest); + return isSuccessful(change, JobType.stagingTest); } return true; // other environments do not have any preconditions } @@ -274,7 +274,7 @@ public class DeploymentJobs { public enum JobError { unknown, - outOfCapacity; + outOfCapacity } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobList.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobList.java index 161035b1164..41060a7af4c 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobList.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobList.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.application; import com.google.common.collect.ImmutableList; @@ -12,8 +13,6 @@ import java.util.Optional; import java.util.function.Function; import java.util.function.Predicate; -import static com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobError.outOfCapacity; - /** * A list of deployment jobs that can be filtered in various ways. * @@ -47,10 +46,10 @@ public class JobList { // TODO: Add sorting based on various stuff, such as deployment order, time of last completion, etc.. - /** Returns the jobstatuses in this as an immutable list */ + /** Returns the job statuses in this as an immutable list */ public List<JobStatus> asList() { return list; } - /** Returns the jobstatuses in this as an immutable list after mapping with the given function */ + /** Returns the job statuses in this as an immutable list after mapping with the given function */ public <Type> List<Type> mapToList(Function<JobStatus, Type> mapper) { return ImmutableList.copyOf(list.stream().map(mapper)::iterator); } @@ -67,7 +66,7 @@ public class JobList { } /** Returns the subset of jobs which are current upgrading */ - public JobList upgrading() { // TODO: Centralise and standardise reasoning about upgrades and revisions. + public JobList upgrading() { // TODO: Centralise and standardise reasoning about upgrades and application versions. return filter(job -> job.lastSuccess().isPresent() && job.lastTriggered().isPresent() && ! job.lastTriggered().get().at().isBefore(job.lastCompleted().get().at()) @@ -86,7 +85,7 @@ public class JobList { /** Returns the subset of jobs which must be failing due to an application change */ public JobList failingApplicationChange() { - return filter(job -> failingApplicationChange(job)); + return filter(JobList::failingApplicationChange); } /** Returns the subset of jobs which are failing with the given job error */ @@ -108,22 +107,22 @@ public class JobList { /** Returns the list in a state where the next filter is for the lastTriggered run type */ public JobRunFilter lastTriggered() { - return new JobRunFilter(job -> job.lastTriggered()); + return new JobRunFilter(JobStatus::lastTriggered); } /** Returns the list in a state where the next filter is for the lastCompleted run type */ public JobRunFilter lastCompleted() { - return new JobRunFilter(job -> job.lastCompleted()); + return new JobRunFilter(JobStatus::lastCompleted); } /** Returns the list in a state where the next filter is for the lastSuccess run type */ public JobRunFilter lastSuccess() { - return new JobRunFilter(job -> job.lastSuccess()); + return new JobRunFilter(JobStatus::lastSuccess); } /** Returns the list in a state where the next filter is for the firstFailing run type */ public JobRunFilter firstFailing() { - return new JobRunFilter(job -> job.firstFailing()); + return new JobRunFilter(JobStatus::firstFailing); } @@ -157,7 +156,7 @@ public class JobList { } public JobList upgrade() { - return filter(run -> run.upgrade()); + return filter(JobRun::upgrade); } /** Transforms the JobRun condition to a JobStatus condition, by considering only the JobRun mapped by which, and executes */ @@ -173,9 +172,9 @@ public class JobList { private static boolean failingApplicationChange(JobStatus job) { if ( job.isSuccess()) return false; if ( ! job.lastSuccess().isPresent()) return true; // An application which never succeeded is surely bad. - if ( ! job.lastSuccess().get().revision().isPresent()) return true; // Indicates the component job, which is always an application change. + if ( job.lastSuccess().get().applicationVersion() == ApplicationVersion.unknown) return true; // Indicates the component job, which is always an application change. if ( ! job.firstFailing().get().version().equals(job.lastSuccess().get().version())) return false; // Version change may be to blame. - return ! job.firstFailing().get().revision().equals(job.lastSuccess().get().revision()); // Return whether there is an application change. + return ! job.firstFailing().get().applicationVersion().equals(job.lastSuccess().get().applicationVersion()); // Return whether there is an application change. } /** Returns a new JobList which is the result of filtering with the -- possibly negated -- condition */ diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobStatus.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobStatus.java index a7940076277..e165d3c9fe5 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobStatus.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobStatus.java @@ -55,20 +55,20 @@ public class JobStatus { return new JobStatus(type, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); } - public JobStatus withTriggering(Version version, Optional<ApplicationRevision> revision, + public JobStatus withTriggering(Version version, ApplicationVersion applicationVersion, boolean upgrade, String reason, Instant triggerTime) { - return new JobStatus(type, jobError, Optional.of(new JobRun(-1, version, revision, upgrade, reason, triggerTime)), + return new JobStatus(type, jobError, Optional.of(new JobRun(-1, version, applicationVersion, upgrade, reason, triggerTime)), lastCompleted, firstFailing, lastSuccess); } public JobStatus withCompletion(long runId, Optional<DeploymentJobs.JobError> jobError, Instant completionTime, Controller controller) { Version version; - Optional<ApplicationRevision> revision; + ApplicationVersion applicationVersion; boolean upgrade; String reason; if (type == DeploymentJobs.JobType.component) { // not triggered by us version = controller.systemVersion(); - revision = Optional.empty(); + applicationVersion = ApplicationVersion.unknown; upgrade = false; reason = "Application commit"; } @@ -79,12 +79,12 @@ public class JobStatus { } else { version = lastTriggered.get().version(); - revision = lastTriggered.get().revision(); + applicationVersion = lastTriggered.get().applicationVersion(); upgrade = lastTriggered.get().upgrade(); reason = lastTriggered.get().reason(); } - JobRun thisCompletion = new JobRun(runId, version, revision, upgrade, reason, completionTime); + JobRun thisCompletion = new JobRun(runId, version, applicationVersion, upgrade, reason, completionTime); Optional<JobRun> firstFailing = this.firstFailing; if (jobError.isPresent() && ! this.firstFailing.isPresent()) @@ -167,20 +167,20 @@ public class JobStatus { private final long id; private final Version version; - private final Optional<ApplicationRevision> revision; + private final ApplicationVersion applicationVersion; private final boolean upgrade; private final String reason; private final Instant at; - public JobRun(long id, Version version, Optional<ApplicationRevision> revision, + public JobRun(long id, Version version, ApplicationVersion applicationVersion, boolean upgrade, String reason, Instant at) { Objects.requireNonNull(version, "version cannot be null"); - Objects.requireNonNull(revision, "revision cannot be null"); + Objects.requireNonNull(applicationVersion, "applicationVersion cannot be null"); Objects.requireNonNull(reason, "Reason cannot be null"); Objects.requireNonNull(at, "at cannot be null"); this.id = id; this.version = version; - this.revision = revision; + this.applicationVersion = applicationVersion; this.upgrade = upgrade; this.reason = reason; this.at = at; @@ -197,8 +197,8 @@ public class JobStatus { /** Returns the Vespa version used on this run */ public Version version() { return version; } - /** Returns the application revision used for this run, or empty when not known */ - public Optional<ApplicationRevision> revision() { return revision; } + /** Returns the application version used for this run, or empty when not known */ + public ApplicationVersion applicationVersion() { return applicationVersion; } /** Returns a human-readable reason for this particular job run */ public String reason() { return reason; } @@ -206,22 +206,17 @@ public class JobStatus { /** Returns the time if this triggering or completion */ public Instant at() { return at; } - // TODO: Consider a version and revision for each JobStatus, to compare against a Target (instead of Change, which is, really, a Target). + // TODO: Consider a version and application version for each JobStatus, to compare against a Target (instead of Change, which is, really, a Target). /** Returns whether the job last completed for the given change */ public boolean lastCompletedWas(Change change) { - if (change instanceof Change.ApplicationChange) { - Change.ApplicationChange applicationChange = (Change.ApplicationChange) change; - return revision().equals(applicationChange.revision()); - } else if (change instanceof Change.VersionChange) { - Change.VersionChange versionChange = (Change.VersionChange) change; - return version().equals(versionChange.version()); - } - throw new IllegalArgumentException("Unexpected change: " + change.getClass()); + if (change.platform().isPresent() && ! change.platform().get().equals(version())) return false; + if (change.application().isPresent() && ! change.application().get().equals(applicationVersion)) return false; + return true; } @Override public int hashCode() { - return Objects.hash(version, revision, upgrade, at); + return Objects.hash(version, applicationVersion, upgrade, at); } @Override @@ -231,14 +226,14 @@ public class JobStatus { JobRun jobRun = (JobRun) o; return id == jobRun.id && Objects.equals(version, jobRun.version) && - Objects.equals(revision, jobRun.revision) && + Objects.equals(applicationVersion, jobRun.applicationVersion) && upgrade == jobRun.upgrade && Objects.equals(at, jobRun.at); } @Override public String toString() { return "job run " + id + " of version " + (upgrade() ? "upgrade " : "") + version + " " - + revision + " at " + at; } + + applicationVersion + " at " + at; } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentOrder.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentOrder.java index d506e8f3dcd..66bb05df308 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentOrder.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentOrder.java @@ -46,7 +46,7 @@ public class DeploymentOrder { /** Returns a list of jobs to trigger after the given job */ // TODO: This does too much - should just tell us the order, as advertised public List<JobType> nextAfter(JobType job, LockedApplication application) { - if ( ! application.deploying().isPresent()) { // Change was cancelled + if ( ! application.change().isPresent()) { // Change was cancelled return Collections.emptyList(); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java index 7908f9b095a..1beab1307c1 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java @@ -10,8 +10,8 @@ import com.yahoo.vespa.hosted.controller.Controller; import com.yahoo.vespa.hosted.controller.LockedApplication; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.ApplicationList; +import com.yahoo.vespa.hosted.controller.application.ApplicationVersion; import com.yahoo.vespa.hosted.controller.application.Change; -import com.yahoo.vespa.hosted.controller.application.Change.VersionChange; import com.yahoo.vespa.hosted.controller.application.Deployment; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobError; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobReport; @@ -79,15 +79,19 @@ public class DeploymentTrigger { public void triggerFromCompletion(JobReport report) { applications().lockOrThrow(report.applicationId(), application -> { application = application.withJobCompletion(report, clock.instant(), controller); + application = application.withProjectId(report.projectId()); // Handle successful starting and ending if (report.success()) { if (report.jobType() == JobType.component) { - if (acceptNewRevisionNow(application)) { + if (acceptNewApplicationVersionNow(application)) { // Set this as the change we are doing, unless we are already pushing a platform change - if ( ! ( application.deploying().isPresent() && - (application.deploying().get() instanceof Change.VersionChange))) - application = application.withDeploying(Optional.of(Change.ApplicationChange.unknown())); + if ( ! ( application.change().platform().isPresent())) { + ApplicationVersion applicationVersion = ApplicationVersion.unknown; + if (report.sourceRevision().isPresent()) + applicationVersion = ApplicationVersion.from(report.sourceRevision().get(), report.buildNumber()); + application = application.withDeploying(Change.of(applicationVersion)); + } } else { // postpone applications().store(application.withOutstandingChange(true)); @@ -96,7 +100,7 @@ public class DeploymentTrigger { } else if (deploymentComplete(application)) { // change completed - application = application.withDeploying(Optional.empty()); + application = application.withDeploying(Change.empty()); } } @@ -117,8 +121,8 @@ public class DeploymentTrigger { /** Returns whether all production zones listed in deployment spec has this change (or a newer version, if upgrade) */ private boolean deploymentComplete(LockedApplication application) { - if ( ! application.deploying().isPresent()) return true; - Change change = application.deploying().get(); + if ( ! application.change().isPresent()) return true; + Change change = application.change(); for (JobType job : order.jobsFrom(application.deploymentSpec())) { if ( ! job.isProduction()) continue; @@ -130,18 +134,15 @@ public class DeploymentTrigger { if (deployment == null) return false; // Check actual job outcome (the deployment) - if (change instanceof VersionChange) { - if (((VersionChange)change).version().isAfter(deployment.version())) return false; // later is ok + if (change.platform().isPresent()) { + if (change.platform().get().isAfter(deployment.version())) return false; // later is ok } - else if (((Change.ApplicationChange)change).revision().isPresent()) { - if ( ! ((Change.ApplicationChange)change).revision().get().equals(deployment.revision())) return false; + if (change.application().isPresent()) { + // If we don't yet know the application version we are deploying, then we are not complete + if (change.application().get() == ApplicationVersion.unknown) return false; + if ( ! change.application().get().equals(deployment.applicationVersion())) return false; } - else { - return false; // If we don't yet know the revision we are changing to, then we are not complete - } - } - return true; } @@ -157,14 +158,14 @@ public class DeploymentTrigger { /** Find the next step to trigger if any, and triggers it */ public void triggerReadyJobs(LockedApplication application) { - if ( ! application.deploying().isPresent()) return; + if ( ! application.change().isPresent()) return; List<JobType> jobs = order.jobsFrom(application.deploymentSpec()); // Should the first step be triggered? if ( ! jobs.isEmpty() && jobs.get(0).equals(JobType.systemTest) ) { JobStatus systemTestStatus = application.deploymentJobs().jobStatus().get(JobType.systemTest); - if (application.deploying().get() instanceof Change.VersionChange) { - Version target = ((Change.VersionChange) application.deploying().get()).version(); + if (application.change().platform().isPresent()) { + Version target = application.change().platform().get(); if (systemTestStatus == null || ! systemTestStatus.lastTriggered().isPresent() || ! systemTestStatus.isSuccess() @@ -203,17 +204,15 @@ public class DeploymentTrigger { } /** - * Returns true if the previous job has completed successfully with a revision and/or version which is - * newer (different) than the one last completed successfully in next + * Returns true if the previous job has completed successfully with a application version and/or Vespa version + * which is newer (different) than the one last completed successfully in next */ private boolean changesAvailable(Application application, JobStatus previous, JobStatus next) { - if ( ! application.deploying().isPresent()) return false; + if ( ! application.change().isPresent()) return false; if (next == null) return true; - Change change = application.deploying().get(); - - if (change instanceof Change.VersionChange) { // Propagate upgrade while making sure we never downgrade - Version targetVersion = ((Change.VersionChange)change).version(); + if (application.change().platform().isPresent()) { // Propagate upgrade while making sure we never downgrade + Version targetVersion = application.change().platform().get(); if (next.type().isTest()) { // Is it not yet this job's turn to upgrade? @@ -242,11 +241,11 @@ public class DeploymentTrigger { return true; } - else { // revision changes do not need to handle downgrading + else { // Application version changes do not need to handle downgrading if ( ! previous.lastSuccess().isPresent()) return false; if ( ! next.lastSuccess().isPresent()) return true; - return previous.lastSuccess().get().revision().isPresent() && - ! previous.lastSuccess().get().revision().equals(next.lastSuccess().get().revision()); + return previous.lastSuccess().get().applicationVersion() != ApplicationVersion.unknown && + ! previous.lastSuccess().get().applicationVersion().equals(next.lastSuccess().get().applicationVersion()); } } @@ -258,14 +257,13 @@ public class DeploymentTrigger { */ public void triggerChange(ApplicationId applicationId, Change change) { applications().lockOrThrow(applicationId, application -> { - if (application.deploying().isPresent() && ! application.deploymentJobs().hasFailures()) + if (application.change().isPresent() && ! application.deploymentJobs().hasFailures()) throw new IllegalArgumentException("Could not start " + change + " on " + application + ": " + - application.deploying().get() + " is already in progress"); - application = application.withDeploying(Optional.of(change)); - if (change instanceof Change.ApplicationChange) + application.change() + " is already in progress"); + application = application.withDeploying(change); + if (change.application().isPresent()) application = application.withOutstandingChange(false); - application = trigger(JobType.systemTest, application, false, - (change instanceof Change.VersionChange ? "Upgrading to " + ((Change.VersionChange)change).version() : "Deploying " + change)); + application = trigger(JobType.systemTest, application, false, change.toString()); applications().store(application); }); } @@ -278,7 +276,7 @@ public class DeploymentTrigger { public void cancelChange(ApplicationId applicationId) { applications().lockOrThrow(applicationId, application -> { buildSystem.removeJobs(application.id()); - applications().store(application.withDeploying(Optional.empty())); + applications().store(application.withDeploying(Change.empty())); }); } @@ -342,7 +340,7 @@ public class DeploymentTrigger { if (jobType == null) return application; // we are passed null when the last job has been reached // Never allow untested changes to go through // Note that this may happen because a new change catches up and prevents an older one from continuing - if ( ! application.deploymentJobs().isDeployableTo(jobType.environment(), application.deploying())) { + if ( ! application.deploymentJobs().isDeployableTo(jobType.environment(), application.change())) { log.warning(String.format("Want to trigger %s for %s with reason %s, but change is untested", jobType, application, reason)); return application; @@ -350,14 +348,14 @@ public class DeploymentTrigger { if ( ! force && ! allowedTriggering(jobType, application)) return application; log.info(String.format("Triggering %s for %s, %s: %s", jobType, application, - application.deploying().map(d -> "deploying " + d).orElse("restarted deployment"), + application.change().isPresent() ? "deploying " + application.change() : "restarted deployment", reason)); buildSystem.addJob(application.id(), jobType, first); return application.withJobTriggering(jobType, - application.deploying(), + application.change(), clock.instant(), application.deployVersionFor(jobType, controller), - application.deployRevisionFor(jobType, controller), + application.deployApplicationVersion(jobType, controller).orElse(ApplicationVersion.unknown), reason); } @@ -368,12 +366,12 @@ public class DeploymentTrigger { // this leads to some additional corner cases, and the possibility of blocking an application // fix to a version upgrade, so not doing it now - if (jobType.isProduction() && application.deploying().isPresent() && - application.deploying().get().blockedBy(application.deploymentSpec(), clock.instant())) return false; + if (jobType.isProduction() && application.change().isPresent() && + application.change().blockedBy(application.deploymentSpec(), clock.instant())) return false; // Don't downgrade or redeploy the same version in production needlessly - if (application.deploying().isPresent() && application.deploying().get() instanceof VersionChange && - jobType.isProduction() && alreadyDeployed(((VersionChange) application.deploying().get()).version(), application, jobType)) return false; + if (application.change().platform().isPresent() && + jobType.isProduction() && alreadyDeployed((application.change().platform().get()), application, jobType)) return false; if (application.deploymentJobs().isRunning(jobType, jobTimeoutLimit())) return false; if ( ! hasJob(jobType, application)) return false; @@ -406,16 +404,17 @@ public class DeploymentTrigger { .orElse(false); } - private boolean acceptNewRevisionNow(LockedApplication application) { - if ( ! application.deploying().isPresent()) return true; + private boolean acceptNewApplicationVersionNow(LockedApplication application) { + if ( ! application.change().isPresent()) return true; - if (application.deploying().get() instanceof Change.ApplicationChange) return true; // more changes are ok + if (application.change().application().isPresent()) return true; // more application changes are ok if (application.deploymentJobs().hasFailures()) return true; // allow changes to fix upgrade problems if (application.isBlocked(clock.instant())) return true; // allow testing changes while upgrade blocked (debatable) - // Otherwise, the application is currently upgrading, without failures, and we should wait with the revision. + // Otherwise, the application is currently upgrading, without failures, and we should wait with the new + // application version. return false; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmer.java index ad373cf8e29..cc977295acf 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmer.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.maintenance; import com.yahoo.config.provision.ApplicationId; diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/OutstandingChangeDeployer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/OutstandingChangeDeployer.java index 4485a603f61..3dd63a511e1 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/OutstandingChangeDeployer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/OutstandingChangeDeployer.java @@ -4,17 +4,18 @@ package com.yahoo.vespa.hosted.controller.maintenance; import com.yahoo.vespa.hosted.controller.Application; import com.yahoo.vespa.hosted.controller.Controller; import com.yahoo.vespa.hosted.controller.application.ApplicationList; +import com.yahoo.vespa.hosted.controller.application.ApplicationVersion; import com.yahoo.vespa.hosted.controller.application.Change; import java.time.Duration; /** * Deploys application changes which have been postponed due to an ongoing upgrade - * + * * @author bratseth */ public class OutstandingChangeDeployer extends Maintainer { - + public OutstandingChangeDeployer(Controller controller, Duration interval, JobControl jobControl) { super(controller, interval, jobControl); } @@ -23,9 +24,9 @@ public class OutstandingChangeDeployer extends Maintainer { protected void maintain() { ApplicationList applications = ApplicationList.from(controller().applications().asList()).notPullRequest(); for (Application application : applications.asList()) { - if (application.hasOutstandingChange() && ! application.deploying().isPresent()) + if (application.hasOutstandingChange() && ! application.change().isPresent()) controller().applications().deploymentTrigger().triggerChange(application.id(), - Change.ApplicationChange.unknown()); + Change.of(ApplicationVersion.unknown)); } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ReadyJobsTrigger.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ReadyJobsTrigger.java index f165b4e4ea3..314f52ca775 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ReadyJobsTrigger.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ReadyJobsTrigger.java @@ -1,19 +1,15 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.maintenance; -import com.yahoo.vespa.hosted.controller.Application; import com.yahoo.vespa.hosted.controller.Controller; -import com.yahoo.vespa.hosted.controller.application.ApplicationList; -import com.yahoo.vespa.hosted.controller.application.Change; import java.time.Duration; /** - * Deploys application changes which have not made it to production because of a revision change block. + * Trigger ready deployment jobs. This drives jobs through each application's deployment pipeline. * * @author bratseth */ -@SuppressWarnings("unused") public class ReadyJobsTrigger extends Maintainer { public ReadyJobsTrigger(Controller controller, Duration interval, JobControl jobControl) { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/Upgrader.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/Upgrader.java index 5b87f9eaa86..75f348904dd 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/Upgrader.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/Upgrader.java @@ -20,7 +20,7 @@ import java.util.logging.Logger; /** * Maintenance job which schedules applications for Vespa version upgrade - * + * * @author bratseth * @author mpolden */ @@ -39,7 +39,7 @@ public class Upgrader extends Maintainer { * Schedule application upgrades. Note that this implementation must be idempotent. */ @Override - public void maintain() { + public void maintain() { // Determine target versions for each upgrade policy Optional<Version> canaryTarget = controller().versionStatus().systemVersion().map(VespaVersion::versionNumber); Optional<Version> defaultTarget = newestVersionWithConfidence(VespaVersion.Confidence.normal); @@ -66,26 +66,25 @@ public class Upgrader extends Maintainer { defaultTarget.ifPresent(target -> upgrade(applications().with(UpgradePolicy.defaultPolicy), target)); conservativeTarget.ifPresent(target -> upgrade(applications().with(UpgradePolicy.conservative), target)); } - + private Optional<Version> newestVersionWithConfidence(VespaVersion.Confidence confidence) { return reversed(controller().versionStatus().versions()).stream() .filter(v -> v.confidence().equalOrHigherThan(confidence)) .findFirst() .map(VespaVersion::versionNumber); } - + private List<VespaVersion> reversed(List<VespaVersion> versions) { List<VespaVersion> reversed = new ArrayList<>(versions.size()); for (int i = 0; i < versions.size(); i++) reversed.add(versions.get(versions.size() - 1 - i)); return reversed; } - + /** Returns a list of all applications */ private ApplicationList applications() { return ApplicationList.from(controller().applications().asList()); } - + private void upgrade(ApplicationList applications, Version version) { - Change.VersionChange change = new Change.VersionChange(version); applications = applications.notPullRequest(); // Pull requests are deployed as separate applications to test then deleted; No need to upgrade applications = applications.hasProductionDeployment(); applications = applications.onLowerVersionThan(version); @@ -96,7 +95,7 @@ public class Upgrader extends Maintainer { applications = applications.first(numberOfApplicationsToUpgrade()); // throttle upgrades for (Application application : applications.asList()) { try { - controller().applications().deploymentTrigger().triggerChange(application.id(), change); + controller().applications().deploymentTrigger().triggerChange(application.id(), Change.of(version)); } catch (IllegalArgumentException e) { log.log(Level.INFO, "Could not trigger change: " + Exceptions.toMessageString(e)); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java index 9c77ebc4bc3..652f95a2d13 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java @@ -15,7 +15,7 @@ import com.yahoo.vespa.config.SlimeUtils; import com.yahoo.vespa.hosted.controller.Application; import com.yahoo.vespa.hosted.controller.api.integration.MetricsService.ApplicationMetrics; import com.yahoo.vespa.hosted.controller.api.integration.organization.IssueId; -import com.yahoo.vespa.hosted.controller.application.ApplicationRevision; +import com.yahoo.vespa.hosted.controller.application.ApplicationVersion; import com.yahoo.vespa.hosted.controller.application.Change; import com.yahoo.vespa.hosted.controller.application.ClusterInfo; import com.yahoo.vespa.hosted.controller.application.ClusterUtilization; @@ -38,7 +38,7 @@ import java.util.Optional; /** * Serializes applications to/from slime. * This class is multithread safe. - * + * * @author bratseth */ public class ApplicationSerializer { @@ -67,12 +67,12 @@ public class ApplicationSerializer { private final String repositoryField = "repositoryField"; private final String branchField = "branchField"; private final String commitField = "commitField"; - + // DeploymentJobs fields private final String projectIdField = "projectId"; private final String jobStatusField = "jobStatus"; private final String issueIdField = "jiraIssueId"; - + // JobStatus field private final String jobTypeField = "jobType"; private final String errorField = "jobError"; @@ -80,7 +80,7 @@ public class ApplicationSerializer { private final String lastCompletedField = "lastCompleted"; private final String firstFailingField = "firstFailing"; private final String lastSuccessField = "lastSuccess"; - + // JobRun fields private final String jobRunIdField = "id"; private final String versionField = "version"; @@ -116,7 +116,7 @@ public class ApplicationSerializer { // ------------------ Serialization - + public Slime toSlime(Application application) { Slime slime = new Slime(); Cursor root = slime.setObject(); @@ -125,7 +125,7 @@ public class ApplicationSerializer { root.setString(validationOverridesField, application.validationOverrides().xmlForm()); deploymentsToSlime(application.deployments().values(), root.setArray(deploymentsField)); toSlime(application.deploymentJobs(), root.setObject(deploymentJobsField)); - toSlime(application.deploying(), root); + toSlime(application.change(), root); root.setBool(outstandingChangeField, application.hasOutstandingChange()); application.ownershipIssueId().ifPresent(issueId -> root.setString(ownershipIssueIdField, issueId.value())); root.setDouble(queryQualityField, application.metrics().queryServiceQuality()); @@ -138,12 +138,12 @@ public class ApplicationSerializer { for (Deployment deployment : deployments) deploymentToSlime(deployment, array.addObject()); } - + private void deploymentToSlime(Deployment deployment, Cursor object) { zoneIdToSlime(deployment.zone(), object.setObject(zoneField)); object.setString(versionField, deployment.version().toString()); object.setLong(deployTimeField, deployment.at().toEpochMilli()); - toSlime(deployment.revision(), object.setObject(applicationPackageRevisionField)); + toSlime(deployment.applicationVersion(), object.setObject(applicationPackageRevisionField)); clusterInfoToSlime(deployment.clusterInfo(), object); clusterUtilsToSlime(deployment.clusterUtils(), object); metricsToSlime(deployment.metrics(), object); @@ -196,19 +196,19 @@ public class ApplicationSerializer { object.setString(environmentField, zone.environment().value()); object.setString(regionField, zone.region().value()); } - - private void toSlime(ApplicationRevision applicationRevision, Cursor object) { - object.setString(applicationPackageHashField, applicationRevision.id()); - if (applicationRevision.source().isPresent()) - toSlime(applicationRevision.source().get(), object.setObject(sourceRevisionField)); + + private void toSlime(ApplicationVersion applicationVersion, Cursor object) { + object.setString(applicationPackageHashField, applicationVersion.id()); + if (applicationVersion.source().isPresent()) + toSlime(applicationVersion.source().get(), object.setObject(sourceRevisionField)); } - + private void toSlime(SourceRevision sourceRevision, Cursor object) { object.setString(repositoryField, sourceRevision.repository()); object.setString(branchField, sourceRevision.branch()); object.setString(commitField, sourceRevision.commit()); } - + private void toSlime(DeploymentJobs deploymentJobs, Cursor cursor) { deploymentJobs.projectId().ifPresent(projectId -> cursor.setLong(projectIdField, projectId)); jobStatusToSlime(deploymentJobs.jobStatus().values(), cursor.setArray(jobStatusField)); @@ -219,7 +219,7 @@ public class ApplicationSerializer { for (JobStatus jobStatus : jobStatuses) toSlime(jobStatus, jobStatusArray.addObject()); } - + private void toSlime(JobStatus jobStatus, Cursor object) { object.setString(jobTypeField, jobStatus.type().jobName()); if (jobStatus.jobError().isPresent()) @@ -230,40 +230,40 @@ public class ApplicationSerializer { jobRunToSlime(jobStatus.firstFailing(), object, firstFailingField); jobRunToSlime(jobStatus.lastSuccess(), object, lastSuccessField); } - + private void jobRunToSlime(Optional<JobStatus.JobRun> jobRun, Cursor parent, String jobRunObjectName) { if ( ! jobRun.isPresent()) return; Cursor object = parent.setObject(jobRunObjectName); object.setLong(jobRunIdField, jobRun.get().id()); object.setString(versionField, jobRun.get().version().toString()); - if ( jobRun.get().revision().isPresent()) - toSlime(jobRun.get().revision().get(), object.setObject(revisionField)); + if ( jobRun.get().applicationVersion() != ApplicationVersion.unknown) + toSlime(jobRun.get().applicationVersion(), object.setObject(revisionField)); object.setBool(upgradeField, jobRun.get().upgrade()); object.setString(reasonField, jobRun.get().reason()); object.setLong(atField, jobRun.get().at().toEpochMilli()); } - - private void toSlime(Optional<Change> deploying, Cursor parentObject) { + + private void toSlime(Change deploying, Cursor parentObject) { if ( ! deploying.isPresent()) return; Cursor object = parentObject.setObject(deployingField); - if (deploying.get() instanceof Change.VersionChange) - object.setString(versionField, ((Change.VersionChange)deploying.get()).version().toString()); - else if (((Change.ApplicationChange)deploying.get()).revision().isPresent()) - toSlime(((Change.ApplicationChange)deploying.get()).revision().get(), object); + if (deploying.platform().isPresent()) + object.setString(versionField, deploying.platform().get().toString()); + if (deploying.application().isPresent() && deploying.application().get() != ApplicationVersion.unknown) + toSlime(deploying.application().get(), object); } // ------------------ Deserialization public Application fromSlime(Slime slime) { Inspector root = slime.get(); - + ApplicationId id = ApplicationId.fromSerializedForm(root.field(idField).asString()); DeploymentSpec deploymentSpec = DeploymentSpec.fromXml(root.field(deploymentSpecField).asString(), false); ValidationOverrides validationOverrides = ValidationOverrides.fromXml(root.field(validationOverridesField).asString()); List<Deployment> deployments = deploymentsFromSlime(root.field(deploymentsField)); DeploymentJobs deploymentJobs = deploymentJobsFromSlime(root.field(deploymentJobsField)); - Optional<Change> deploying = changeFromSlime(root.field(deployingField)); + Change deploying = changeFromSlime(root.field(deployingField)); boolean outstandingChange = root.field(outstandingChangeField).asBool(); Optional<IssueId> ownershipIssueId = optionalString(root.field(ownershipIssueIdField)).map(IssueId::from); ApplicationMetrics metrics = new ApplicationMetrics(root.field(queryQualityField).asDouble(), @@ -282,7 +282,7 @@ public class ApplicationSerializer { private Deployment deploymentFromSlime(Inspector deploymentObject) { return new Deployment(zoneIdFromSlime(deploymentObject.field(zoneField)), - applicationRevisionFromSlime(deploymentObject.field(applicationPackageRevisionField)).get(), + applicationVersionFromSlime(deploymentObject.field(applicationPackageRevisionField)), Version.fromString(deploymentObject.field(versionField).asString()), Instant.ofEpochMilli(deploymentObject.field(deployTimeField).asLong()), clusterUtilsMapFromSlime(deploymentObject.field(clusterUtilsField)), @@ -340,14 +340,14 @@ public class ApplicationSerializer { return ZoneId.from(object.field(environmentField).asString(), object.field(regionField).asString()); } - private Optional<ApplicationRevision> applicationRevisionFromSlime(Inspector object) { - if ( ! object.valid()) return Optional.empty(); + private ApplicationVersion applicationVersionFromSlime(Inspector object) { + if ( ! object.valid()) return ApplicationVersion.unknown; String applicationPackageHash = object.field(applicationPackageHashField).asString(); Optional<SourceRevision> sourceRevision = sourceRevisionFromSlime(object.field(sourceRevisionField)); - return sourceRevision.isPresent() ? Optional.of(ApplicationRevision.from(applicationPackageHash, sourceRevision.get())) - : Optional.of(ApplicationRevision.from(applicationPackageHash)); + return sourceRevision.isPresent() ? ApplicationVersion.from(applicationPackageHash, sourceRevision.get()) + : ApplicationVersion.from(applicationPackageHash); } - + private Optional<SourceRevision> sourceRevisionFromSlime(Inspector object) { if ( ! object.valid()) return Optional.empty(); return Optional.of(new SourceRevision(object.field(repositoryField).asString(), @@ -363,23 +363,25 @@ public class ApplicationSerializer { return new DeploymentJobs(projectId, jobStatusList, issueId); } - private Optional<Change> changeFromSlime(Inspector object) { - if ( ! object.valid()) return Optional.empty(); + private Change changeFromSlime(Inspector object) { + if ( ! object.valid()) return Change.empty(); Inspector versionFieldValue = object.field(versionField); + Change change = Change.empty(); if (versionFieldValue.valid()) - return Optional.of(new Change.VersionChange(Version.fromString(versionFieldValue.asString()))); - else if (object.field(applicationPackageHashField).valid()) - return Optional.of(Change.ApplicationChange.of(applicationRevisionFromSlime(object).get())); - else - return Optional.of(Change.ApplicationChange.unknown()); + change = Change.of(Version.fromString(versionFieldValue.asString())); + if (object.field(applicationPackageHashField).valid()) + change = change.with(applicationVersionFromSlime(object)); + if ( ! change.isPresent()) // A deploy object with no fields -> unknown application change + change = Change.of(ApplicationVersion.unknown); + return change; } - + private List<JobStatus> jobStatusListFromSlime(Inspector array) { List<JobStatus> jobStatusList = new ArrayList<>(); array.traverse((ArrayTraverser) (int i, Inspector item) -> jobStatusList.add(jobStatusFromSlime(item))); return jobStatusList; } - + private JobStatus jobStatusFromSlime(Inspector object) { DeploymentJobs.JobType jobType = DeploymentJobs.JobType.fromJobName(object.field(jobTypeField).asString()); @@ -398,7 +400,7 @@ public class ApplicationSerializer { if ( ! object.valid()) return Optional.empty(); return Optional.of(new JobStatus.JobRun(optionalLong(object.field(jobRunIdField)).orElse(-1L), // TODO: Make non-optional after November 2017 -- what about lastTriggered? new Version(object.field(versionField).asString()), - applicationRevisionFromSlime(object.field(revisionField)), + applicationVersionFromSlime(object.field(revisionField)), object.field(upgradeField).asBool(), optionalString(object.field(reasonField)).orElse(""), // TODO: Make non-optional after November 2017 Instant.ofEpochMilli(object.field(atField).asLong()))); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java index dc816d70b7f..9b2174b881d 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java @@ -54,7 +54,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.configserver.Log; import com.yahoo.vespa.hosted.controller.api.integration.organization.User; import com.yahoo.vespa.hosted.controller.api.integration.routing.RotationStatus; import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; -import com.yahoo.vespa.hosted.controller.application.ApplicationRevision; +import com.yahoo.vespa.hosted.controller.application.ApplicationVersion; import com.yahoo.vespa.hosted.controller.application.Change; import com.yahoo.vespa.hosted.controller.application.ClusterCost; import com.yahoo.vespa.hosted.controller.application.ClusterUtilization; @@ -98,7 +98,7 @@ import java.util.logging.Level; /** * This implements the application/v4 API which is used to deploy and manage applications * on hosted Vespa. - * + * * @author bratseth * @author mpolden */ @@ -123,7 +123,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { public Duration getTimeout() { return Duration.ofMinutes(20); // deploys may take a long time; } - + @Override public HttpResponse handle(HttpRequest request) { try { @@ -156,7 +156,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { return ErrorResponse.internalServerError(Exceptions.toMessageString(e)); } } - + private HttpResponse handleGET(HttpRequest request) { Path path = new Path(request.getUri().getPath()); if (path.matches("/application/v4/")) return root(request); @@ -213,7 +213,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { return setGlobalRotationOverride(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), true, request); return ErrorResponse.notFoundError("Nothing at " + path); } - + private HttpResponse handleOPTIONS() { // We implement this to avoid redirect loops on OPTIONS requests from browsers, but do not really bother // spelling out the methods supported at each path, which we should @@ -235,7 +235,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { ? recursiveRoot(request) : new ResourceResponse(request, "user", "tenant", "tenant-pipeline", "athensDomain", "property", "cookiefreshness"); } - + private HttpResponse authenticatedUser(HttpRequest request) { String userIdString = request.getProperty("userOverride"); if (userIdString == null) @@ -243,7 +243,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { .map(UserId::id) .orElseThrow(() -> new ForbiddenException("You must be authenticated or specify userOverride")); UserId userId = new UserId(userIdString); - + List<Tenant> tenants = controller.tenants().asList(userId); Slime slime = new Slime(); @@ -255,7 +255,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { response.setBool("tenantExists", tenants.stream().map(Tenant::getId).anyMatch(id -> id.isTenantFor(userId))); return new SlimeJsonResponse(slime); } - + private HttpResponse tenants(HttpRequest request) { Slime slime = new Slime(); Cursor response = slime.setArray(); @@ -263,7 +263,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { tenantInTenantsListToSlime(tenant, request.getUri(), response.addObject()); return new SlimeJsonResponse(slime); } - + /** Lists the screwdriver project id for each application */ private HttpResponse tenantPipelines() { Slime slime = new Slime(); @@ -281,7 +281,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { response.setArray("brokenTenantPipelines"); // not used but may need to be present return new SlimeJsonResponse(slime); } - + private HttpResponse athenzDomains(HttpRequest request) { Slime slime = new Slime(); Cursor response = slime.setObject(); @@ -307,7 +307,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { private HttpResponse cookieFreshness(HttpRequest request) { Slime slime = new Slime(); String passThruHeader = request.getHeader(SetBouncerPassthruHeaderFilter.BOUNCER_PASSTHRU_HEADER_FIELD); - slime.setObject().setBool("shouldRefreshCookie", + slime.setObject().setBool("shouldRefreshCookie", ! SetBouncerPassthruHeaderFilter.BOUNCER_PASSTHRU_COOKIE_OK.equals(passThruHeader)); return new SlimeJsonResponse(slime); } @@ -332,7 +332,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { toSlime(application, array.addObject(), request); return new SlimeJsonResponse(slime); } - + private HttpResponse application(String tenantName, String applicationName, HttpRequest request) { ApplicationId applicationId = ApplicationId.from(tenantName, applicationName, "default"); Application application = @@ -348,12 +348,12 @@ public class ApplicationApiHandler extends LoggingRequestHandler { object.setString("application", application.id().application().value()); object.setString("instance", application.id().instance().value()); // Currently deploying change - if (application.deploying().isPresent()) { + if (application.change().isPresent()) { Cursor deployingObject = object.setObject("deploying"); - if (application.deploying().get() instanceof Change.VersionChange) - deployingObject.setString("version", ((Change.VersionChange)application.deploying().get()).version().toString()); - else if (((Change.ApplicationChange)application.deploying().get()).revision().isPresent()) - toSlime(((Change.ApplicationChange)application.deploying().get()).revision().get(), deployingObject.setObject("revision")); + application.change().platform().ifPresent(v -> deployingObject.setString("version", v.toString())); + application.change().application() + .filter(v -> v != ApplicationVersion.unknown) + .ifPresent(v -> toSlime(v, deployingObject.setObject("revision"))); } // Jobs sorted according to deployment spec @@ -453,14 +453,14 @@ public class ApplicationApiHandler extends LoggingRequestHandler { response.setString("yamasUrl", monitoringSystemUri(deploymentId).toString()); response.setString("version", deployment.version().toFullString()); - response.setString("revision", deployment.revision().id()); + response.setString("revision", deployment.applicationVersion().id()); response.setLong("deployTimeEpochMs", deployment.at().toEpochMilli()); controller.zoneRegistry().getDeploymentTimeToLive(deploymentId.zoneId()) .ifPresent(deploymentTimeToLive -> response.setLong("expiryTimeEpochMs", deployment.at().plus(deploymentTimeToLive).toEpochMilli())); controller.applications().get(deploymentId.applicationId()).flatMap(application -> application.deploymentJobs().projectId()) .ifPresent(i -> response.setString("screwdriverId", String.valueOf(i))); - sourceRevisionToSlime(deployment.revision().source(), response); + sourceRevisionToSlime(deployment.applicationVersion().source(), response); // Cost DeploymentCost appCost = deployment.calculateCost(); @@ -477,10 +477,10 @@ public class ApplicationApiHandler extends LoggingRequestHandler { metricsObject.setDouble("writeLatencyMillis", metrics.writeLatencyMillis()); } - private void toSlime(ApplicationRevision revision, Cursor object) { - object.setString("hash", revision.id()); - if (revision.source().isPresent()) - sourceRevisionToSlime(revision.source(), object.setObject("source")); + private void toSlime(ApplicationVersion applicationVersion, Cursor object) { + object.setString("hash", applicationVersion.id()); + if (applicationVersion.source().isPresent()) + sourceRevisionToSlime(applicationVersion.source(), object.setObject("source")); } private void sourceRevisionToSlime(Optional<SourceRevision> revision, Cursor object) { @@ -594,7 +594,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { response.setResponse(result, serviceName, restPath); return response; } - + private HttpResponse createUser(HttpRequest request) { Optional<UserId> user = userFrom(request); if ( ! user.isPresent() ) throw new ForbiddenException("Not authenticated."); @@ -711,11 +711,11 @@ public class ApplicationApiHandler extends LoggingRequestHandler { ApplicationId id = ApplicationId.from(tenantName, applicationName, "default"); controller.applications().lockOrThrow(id, application -> { - if (application.deploying().isPresent()) + if (application.change().isPresent()) throw new IllegalArgumentException("Can not start a deployment of " + application + " at this time: " + - application.deploying().get() + " is in progress"); + application.change() + " is in progress"); - controller.applications().deploymentTrigger().triggerChange(application.id(), new Change.VersionChange(version)); + controller.applications().deploymentTrigger().triggerChange(application.id(), Change.of(version)); }); return new MessageResponse("Triggered deployment of application '" + id + "' on version " + version); } @@ -724,14 +724,14 @@ public class ApplicationApiHandler extends LoggingRequestHandler { private HttpResponse cancelDeploy(String tenantName, String applicationName) { ApplicationId id = ApplicationId.from(tenantName, applicationName, "default"); Application application = controller.applications().require(id); - Optional<Change> change = application.deploying(); + Change change = application.change(); if ( ! change.isPresent()) return new MessageResponse("No deployment in progress for " + application + " at this time"); controller.applications().lockOrThrow(id, lockedApplication -> controller.applications().deploymentTrigger().cancelChange(id)); - return new MessageResponse("Cancelled " + change.get() + " for " + application); + return new MessageResponse("Cancelled " + change + " for " + application); } /** Schedule restart of deployment, or specific host in a deployment */ @@ -775,12 +775,11 @@ public class ApplicationApiHandler extends LoggingRequestHandler { Map<String, byte[]> dataParts = new MultipartParser().parse(request); if ( ! dataParts.containsKey("deployOptions")) return ErrorResponse.badRequest("Missing required form part 'deployOptions'"); - if ( ! dataParts.containsKey("applicationZip")) - return ErrorResponse.badRequest("Missing required form part 'applicationZip'"); Inspector deployOptions = SlimeUtils.jsonToSlime(dataParts.get("deployOptions")).get(); - ApplicationPackage applicationPackage = new ApplicationPackage(dataParts.get("applicationZip")); + Optional<ApplicationPackage> applicationPackage = Optional.ofNullable(dataParts.get("applicationZip")) + .map(ApplicationPackage::new); DeployAuthorizer deployAuthorizer = new DeployAuthorizer(controller.zoneRegistry(), athenzClientFactory); Tenant tenant = controller.tenants().tenant(new TenantId(tenantName)).orElseThrow(() -> new NotExistsException(new TenantId(tenantName))); Principal principal = authorizer.getPrincipal(request); @@ -791,12 +790,11 @@ public class ApplicationApiHandler extends LoggingRequestHandler { optional("vespaVersion", deployOptions).map(Version::new), deployOptions.field("ignoreValidationErrors").asBool(), deployOptions.field("deployCurrentVersion").asBool()); - controller.applications().validate(applicationPackage.deploymentSpec()); ActivateResult result = controller.applications().deployApplication(applicationId, zone, applicationPackage, deployOptionsJsonClass); - return new SlimeJsonResponse(toSlime(result, dataParts.get("applicationZip").length)); + return new SlimeJsonResponse(toSlime(result)); } private HttpResponse deleteTenant(String tenantName, HttpRequest request) { @@ -972,7 +970,8 @@ public class ApplicationApiHandler extends LoggingRequestHandler { private void toSlime(JobStatus.JobRun jobRun, Cursor object) { object.setLong("id", jobRun.id()); object.setString("version", jobRun.version().toFullString()); - jobRun.revision().ifPresent(revision -> toSlime(revision, object.setObject("revision"))); + if (jobRun.applicationVersion() != ApplicationVersion.unknown) + toSlime(jobRun.applicationVersion(), object.setObject("revision")); object.setString("reason", jobRun.reason()); object.setLong("at", jobRun.at().toEpochMilli()); } @@ -1027,14 +1026,14 @@ public class ApplicationApiHandler extends LoggingRequestHandler { "/application/" + application.id().application().value(), request.getUri()).toString()); } - private Slime toSlime(ActivateResult result, long applicationZipSizeBytes) { + private Slime toSlime(ActivateResult result) { Slime slime = new Slime(); Cursor object = slime.setObject(); - object.setString("revisionId", result.getRevisionId().id()); - object.setLong("applicationZipSize", applicationZipSizeBytes); + object.setString("revisionId", result.revisionId().id()); + object.setLong("applicationZipSize", result.applicationZipSizeBytes()); Cursor logArray = object.setArray("prepareMessages"); - if (result.getPrepareResponse().log != null) { - for (Log logMessage : result.getPrepareResponse().log) { + if (result.prepareResponse().log != null) { + for (Log logMessage : result.prepareResponse().log) { Cursor logObject = logArray.addObject(); logObject.setLong("time", logMessage.time); logObject.setString("level", logMessage.level); @@ -1045,7 +1044,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { Cursor changeObject = object.setObject("configChangeActions"); Cursor restartActionsArray = changeObject.setArray("restart"); - for (RestartAction restartAction : result.getPrepareResponse().configChangeActions.restartActions) { + for (RestartAction restartAction : result.prepareResponse().configChangeActions.restartActions) { Cursor restartActionObject = restartActionsArray.addObject(); restartActionObject.setString("clusterName", restartAction.clusterName); restartActionObject.setString("clusterType", restartAction.clusterType); @@ -1055,7 +1054,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { } Cursor refeedActionsArray = changeObject.setArray("refeed"); - for (RefeedAction refeedAction : result.getPrepareResponse().configChangeActions.refeedActions) { + for (RefeedAction refeedAction : result.prepareResponse().configChangeActions.refeedActions) { Cursor refeedActionObject = refeedActionsArray.addObject(); refeedActionObject.setString("name", refeedAction.name); refeedActionObject.setBool("allowed", refeedAction.allowed); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/DeployAuthorizer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/DeployAuthorizer.java index 323da24b47d..ee8deef7256 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/DeployAuthorizer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/DeployAuthorizer.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.restapi.application; +import com.yahoo.config.application.api.DeploymentSpec; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.Environment; import com.yahoo.vespa.athenz.api.AthenzDomain; @@ -16,6 +17,7 @@ import javax.ws.rs.ForbiddenException; import javax.ws.rs.NotAuthorizedException; import java.security.Principal; import java.util.Objects; +import java.util.Optional; import java.util.logging.Logger; import static com.yahoo.vespa.hosted.controller.api.integration.athenz.HostedAthenzIdentities.SCREWDRIVER_DOMAIN; @@ -41,9 +43,10 @@ public class DeployAuthorizer { Environment environment, Tenant tenant, ApplicationId applicationId, - ApplicationPackage applicationPackage) { + Optional<ApplicationPackage> applicationPackage) { // Validate that domain in identity configuration (deployment.xml) is same as tenant domain - applicationPackage.deploymentSpec().athenzDomain().ifPresent(identityDomain -> { + applicationPackage.map(ApplicationPackage::deploymentSpec).flatMap(DeploymentSpec::athenzDomain) + .ifPresent(identityDomain -> { AthenzDomain tenantDomain = tenant.getAthensDomain().orElseThrow(() -> new IllegalArgumentException("Identity provider only available to Athenz onboarded tenants")); if (! Objects.equals(tenantDomain.getName(), identityDomain.value())) { throw new ForbiddenException( diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ArtifactRepositoryMock.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ArtifactRepositoryMock.java new file mode 100644 index 00000000000..efbc10e8deb --- /dev/null +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ArtifactRepositoryMock.java @@ -0,0 +1,39 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller; + +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.vespa.hosted.controller.api.integration.deployment.ArtifactRepository; +import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * @author mpolden + */ +public class ArtifactRepositoryMock implements ArtifactRepository { + + private final Map<Integer, byte[]> repository = new HashMap<>(); + + public ArtifactRepositoryMock put(ApplicationId applicationId, ApplicationPackage applicationPackage, + String applicationVersion) { + repository.put(artifactHash(applicationId, applicationVersion), applicationPackage.zippedContent()); + return this; + } + + @Override + public byte[] getApplicationPackage(ApplicationId applicationId, String applicationVersion) { + int artifactHash = artifactHash(applicationId, applicationVersion); + if (!repository.containsKey(artifactHash)) { + throw new IllegalArgumentException("No application package found for " + applicationId + " with version " + + applicationVersion); + } + return repository.get(artifactHash); + } + + private static int artifactHash(ApplicationId applicationId, String applicationVersion) { + return Objects.hash(applicationId, applicationVersion); + } + +} diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java index ceed52d2dad..8d03b4f7121 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java @@ -10,11 +10,11 @@ import com.yahoo.config.provision.InstanceName; import com.yahoo.config.provision.RegionName; import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.TenantName; +import com.yahoo.vespa.athenz.api.AthenzDomain; import com.yahoo.vespa.config.SlimeUtils; import com.yahoo.vespa.hosted.controller.api.Tenant; import com.yahoo.vespa.hosted.controller.api.application.v4.model.DeployOptions; import com.yahoo.vespa.hosted.controller.api.application.v4.model.EndpointStatus; -import com.yahoo.vespa.athenz.api.AthenzDomain; import com.yahoo.vespa.hosted.controller.api.identifiers.DeploymentId; import com.yahoo.vespa.hosted.controller.api.identifiers.Property; import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; @@ -26,12 +26,13 @@ import com.yahoo.vespa.hosted.controller.api.integration.dns.Record; import com.yahoo.vespa.hosted.controller.api.integration.dns.RecordName; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; -import com.yahoo.vespa.hosted.controller.application.ApplicationRevision; +import com.yahoo.vespa.hosted.controller.application.ApplicationVersion; import com.yahoo.vespa.hosted.controller.application.Change; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobError; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobType; import com.yahoo.vespa.hosted.controller.application.JobStatus; +import com.yahoo.vespa.hosted.controller.application.SourceRevision; import com.yahoo.vespa.hosted.controller.athenz.mock.AthenzDbMock; import com.yahoo.vespa.hosted.controller.deployment.ApplicationPackageBuilder; import com.yahoo.vespa.hosted.controller.deployment.BuildSystem; @@ -100,21 +101,22 @@ public class ControllerTest { Version version1 = Version.fromString("6.1"); // Set in config server mock Application app1 = tester.createApplication("app1", "tenant1", 1, 11L); tester.notifyJobCompletion(component, app1, true); - assertFalse("Revision is currently not known", - ((Change.ApplicationChange)tester.controller().applications().require(app1.id()).deploying().get()).revision().isPresent()); + assertEquals("Application version is currently not known", + ApplicationVersion.unknown, + tester.controller().applications().require(app1.id()).change().application().get()); tester.deployAndNotify(app1, applicationPackage, true, systemTest); tester.deployAndNotify(app1, applicationPackage, true, stagingTest); assertEquals(4, applications.require(app1.id()).deploymentJobs().jobStatus().size()); - Optional<ApplicationRevision> revision = ((Change.ApplicationChange)tester.controller().applications().require(app1.id()).deploying().get()).revision(); - assertTrue("Revision has been set during deployment", revision.isPresent()); + ApplicationVersion applicationVersion = tester.controller().applications().require(app1.id()).change().application().get(); + assertTrue("Application version has been set during deployment", applicationVersion != ApplicationVersion.unknown); assertStatus(JobStatus.initial(stagingTest) - .withTriggering(version1, revision, false, "", tester.clock().instant().minus(Duration.ofMillis(1))) + .withTriggering(version1, applicationVersion, false, "", tester.clock().instant().minus(Duration.ofMillis(1))) .withCompletion(42, Optional.empty(), tester.clock().instant(), tester.controller()), app1.id(), tester.controller()); // Causes first deployment job to be triggered assertStatus(JobStatus.initial(productionCorpUsEast1) - .withTriggering(version1, revision, false, "", tester.clock().instant()), app1.id(), tester.controller()); + .withTriggering(version1, applicationVersion, false, "", tester.clock().instant()), app1.id(), tester.controller()); tester.clock().advance(Duration.ofSeconds(1)); // production job (failing) @@ -122,9 +124,9 @@ public class ControllerTest { assertEquals(4, applications.require(app1.id()).deploymentJobs().jobStatus().size()); JobStatus expectedJobStatus = JobStatus.initial(productionCorpUsEast1) - .withTriggering(version1, revision, false, "", tester.clock().instant()) // Triggered first without revision info + .withTriggering(version1, applicationVersion, false, "", tester.clock().instant()) // Triggered first without application version info .withCompletion(42, Optional.of(JobError.unknown), tester.clock().instant(), tester.controller()) - .withTriggering(version1, revision, false, "", tester.clock().instant()); // Re-triggering (due to failure) has revision info + .withTriggering(version1, applicationVersion, false, "", tester.clock().instant()); // Re-triggering (due to failure) has application version info assertStatus(expectedJobStatus, app1.id(), tester.controller()); @@ -147,25 +149,25 @@ public class ControllerTest { tester.notifyJobCompletion(component, app1, true); tester.deployAndNotify(app1, applicationPackage, true, false, systemTest); assertStatus(JobStatus.initial(systemTest) - .withTriggering(version1, revision, false, "", tester.clock().instant().minus(Duration.ofMillis(1))) + .withTriggering(version1, applicationVersion, false, "", tester.clock().instant().minus(Duration.ofMillis(1))) .withCompletion(42, Optional.empty(), tester.clock().instant(), tester.controller()), app1.id(), tester.controller()); tester.deployAndNotify(app1, applicationPackage, true, stagingTest); // production job succeeding now tester.deployAndNotify(app1, applicationPackage, true, productionCorpUsEast1); expectedJobStatus = expectedJobStatus - .withTriggering(version1, revision, false, "", tester.clock().instant().minus(Duration.ofMillis(1))) + .withTriggering(version1, applicationVersion, false, "", tester.clock().instant().minus(Duration.ofMillis(1))) .withCompletion(42, Optional.empty(), tester.clock().instant(), tester.controller()); assertStatus(expectedJobStatus, app1.id(), tester.controller()); // causes triggering of next production job assertStatus(JobStatus.initial(productionUsEast3) - .withTriggering(version1, revision, false, "", tester.clock().instant()), + .withTriggering(version1, applicationVersion, false, "", tester.clock().instant()), app1.id(), tester.controller()); tester.deployAndNotify(app1, applicationPackage, true, productionUsEast3); assertEquals(5, applications.get(app1.id()).get().deploymentJobs().jobStatus().size()); - + // prod zone removal is not allowed applicationPackage = new ApplicationPackageBuilder() .environment(Environment.prod) @@ -177,7 +179,7 @@ public class ControllerTest { fail("Expected exception due to unallowed production deployment removal"); } catch (IllegalArgumentException e) { - assertEquals("deployment-removal: application 'tenant1.app1' is deployed in corp-us-east-1, but does not include this zone in deployment.xml", e.getMessage()); + assertEquals("deployment-removal: application 'tenant1.app1' is deployed in corp-us-east-1, but does not include this zone in deployment.xml", e.getMessage()); } assertNotNull("Zone was not removed", applications.require(app1.id()).deployments().get(productionCorpUsEast1.zone(SystemName.main).get())); @@ -198,7 +200,132 @@ public class ControllerTest { applications.require(app1.id()).deployments().get(productionCorpUsEast1.zone(SystemName.main).get())); assertNull("Deployment job was removed", applications.require(app1.id()).deploymentJobs().jobStatus().get(productionCorpUsEast1)); } - + + // TODO: Replace above test with this one after introducing new application version number + @Test + public void testDeploymentWithApplicationVersion() { + // Setup system + DeploymentTester tester = new DeploymentTester(); + ApplicationController applications = tester.controller().applications(); + Version version1 = Version.fromString("6.1"); // Set in config server mock + Application app1 = tester.createApplication("app1", "tenant1", 1, 11L); + + // Component runs, uploads artifact and notifies completion + ApplicationPackage applicationPackage = new ApplicationPackageBuilder() + .environment(Environment.prod) + .region("corp-us-east-1") + .region("us-east-3") + .build(); + SourceRevision source = new SourceRevision("repo", "branch", "deadbeef"); + String expectedVersionString = "1.0.37-deadbeef"; + tester.artifactRepository().put(app1.id(), applicationPackage, expectedVersionString); + tester.notifyJobCompletion(component, app1, Optional.empty(), Optional.of(source), 37); + ApplicationVersion expectedVersion = ApplicationVersion.from(source, 37); + assertEquals(expectedVersionString, tester.controller().applications() + .require(app1.id()) + .change().application().get().id()); + + // Deploy without application package + tester.deployAndNotify(app1, true, systemTest); + tester.deployAndNotify(app1, true, stagingTest); + assertEquals(4, applications.require(app1.id()).deploymentJobs().jobStatus().size()); + assertStatus(JobStatus.initial(stagingTest) + .withTriggering(version1, expectedVersion, false, "", tester.clock().instant().minus(Duration.ofMillis(1))) + .withCompletion(42, Optional.empty(), tester.clock().instant(), tester.controller()), app1.id(), tester.controller()); + + // Causes first deployment job to be triggered + assertStatus(JobStatus.initial(productionCorpUsEast1) + .withTriggering(version1, expectedVersion, false, "", tester.clock().instant()), app1.id(), tester.controller()); + tester.clock().advance(Duration.ofSeconds(1)); + + // production job (failing) + tester.deployAndNotify(app1, false, productionCorpUsEast1); + assertEquals(4, applications.require(app1.id()).deploymentJobs().jobStatus().size()); + + JobStatus expectedJobStatus = JobStatus.initial(productionCorpUsEast1) + .withTriggering(version1, expectedVersion, false, "", tester.clock().instant()) + .withCompletion(42, Optional.of(JobError.unknown), tester.clock().instant(), tester.controller()); + + assertStatus(expectedJobStatus, app1.id(), tester.controller()); + + // Simulate restart + tester.restartController(); + applications = tester.controller().applications(); + + assertNotNull(tester.controller().tenants().tenant(new TenantId("tenant1"))); + assertNotNull(applications.get(ApplicationId.from(TenantName.from("tenant1"), + ApplicationName.from("application1"), + InstanceName.from("default")))); + assertEquals(4, applications.require(app1.id()).deploymentJobs().jobStatus().size()); + + + tester.clock().advance(Duration.ofHours(1)); + + tester.notifyJobCompletion(productionCorpUsEast1, app1, false); // Need to complete the job, or new jobs won't start. + + // Component is triggered again + tester.artifactRepository().put(app1.id(), applicationPackage, "1.0.38-deadbeef"); + tester.notifyJobCompletion(component, app1, Optional.empty(), Optional.of(source), 38); + tester.deployAndNotify(app1, Optional.empty(), true, false, systemTest); + expectedVersion = ApplicationVersion.from(source, 38); + assertStatus(JobStatus.initial(systemTest) + .withTriggering(version1, expectedVersion, false, "", tester.clock().instant().minus(Duration.ofMillis(1))) + .withCompletion(42, Optional.empty(), tester.clock().instant(), tester.controller()), app1.id(), tester.controller()); + tester.deployAndNotify(app1, Optional.empty(), true, true, stagingTest); + + // production job succeeding now + tester.deployAndNotify(app1, Optional.empty(), true, true, productionCorpUsEast1); + expectedJobStatus = expectedJobStatus + .withTriggering(version1, expectedVersion, false, "", tester.clock().instant().minus(Duration.ofMillis(1))) + .withCompletion(42, Optional.empty(), tester.clock().instant(), tester.controller()); + assertStatus(expectedJobStatus, app1.id(), tester.controller()); + + // causes triggering of next production job + assertStatus(JobStatus.initial(productionUsEast3) + .withTriggering(version1, expectedVersion, false, "", tester.clock().instant()), + app1.id(), tester.controller()); + tester.deployAndNotify(app1, Optional.empty(), true, true, productionUsEast3); + + assertEquals(5, applications.get(app1.id()).get().deploymentJobs().jobStatus().size()); + + // prod zone removal is not allowed + applicationPackage = new ApplicationPackageBuilder() + .environment(Environment.prod) + .region("us-east-3") + .build(); + tester.artifactRepository().put(app1.id(), applicationPackage, "1.0.56-cafed00d"); + source = new SourceRevision("repo", "branch", "cafed00d"); + tester.notifyJobCompletion(component, app1, Optional.empty(), Optional.of(source), 56); + try { + tester.deploy(systemTest, app1, Optional.empty(), false); + fail("Expected exception due to unallowed production deployment removal"); + } + catch (IllegalArgumentException e) { + assertEquals("deployment-removal: application 'tenant1.app1' is deployed in corp-us-east-1, but does not include this zone in deployment.xml", e.getMessage()); + } + assertNotNull("Zone was not removed", + applications.require(app1.id()).deployments().get(productionCorpUsEast1.zone(SystemName.main).get())); + JobStatus jobStatus = applications.require(app1.id()).deploymentJobs().jobStatus().get(productionCorpUsEast1); + assertNotNull("Deployment job was not removed", jobStatus); + assertEquals(42, jobStatus.lastCompleted().get().id()); + assertEquals("staging-test completed", jobStatus.lastCompleted().get().reason()); + + // prod zone removal is allowed with override + applicationPackage = new ApplicationPackageBuilder() + .allow(ValidationId.deploymentRemoval) + .upgradePolicy("default") + .environment(Environment.prod) + .region("us-east-3") + .build(); + tester.artifactRepository().put(app1.id(), applicationPackage, "1.0.103-c00ffefe"); + source = new SourceRevision("repo", "branch", "c00ffefe"); + tester.notifyJobCompletion(component, app1, Optional.empty(), Optional.of(source), 103); + tester.deployAndNotify(app1, Optional.empty(), true, true, systemTest); + assertNull("Zone was removed", + applications.require(app1.id()).deployments().get(productionCorpUsEast1.zone(SystemName.main).get())); + assertNull("Deployment job was removed", applications.require(app1.id()).deploymentJobs().jobStatus().get(productionCorpUsEast1)); + } + @Test public void testDeployVersion() { // Setup system @@ -250,11 +377,10 @@ public class ControllerTest { app1 = applications.require(app1.id()); assertEquals("Application change preserves version", systemVersion, app1.oldestDeployedVersion().get()); assertEquals(systemVersion, tester.configServer().lastPrepareVersion().get()); - assertFalse("Change deployed", app1.deploying().isPresent()); + assertFalse("Change deployed", app1.change().isPresent()); // Version upgrade changes system version - Change.VersionChange change = new Change.VersionChange(newSystemVersion); - applications.deploymentTrigger().triggerChange(app1.id(), change); + applications.deploymentTrigger().triggerChange(app1.id(), Change.of(newSystemVersion)); tester.deployAndNotify(app1, applicationPackage, true, systemTest); tester.deployAndNotify(app1, applicationPackage, true, stagingTest); tester.deployAndNotify(app1, applicationPackage, true, productionUsWest1); @@ -288,7 +414,7 @@ public class ControllerTest { controller.updateVersionStatus(new VersionStatus(versions)); return newSystemVersion; } - + @Test public void testPullRequestDeployment() { // Setup system @@ -300,7 +426,7 @@ public class ControllerTest { ApplicationId app1 = tester.createAndDeploy("tenant1", "domain1", "application1", Environment.staging, app1ProjectId).id(); - + // pull-request deployment - uses different instance id ApplicationId app1pr = tester.createAndDeploy("tenant1", "domain1", "application1", "default-pr1", @@ -334,7 +460,7 @@ public class ControllerTest { .filter(app -> app.id().application().equals(app2.application())) .count()); } - + @Test public void testFailingSinceUpdates() { // Setup system @@ -347,13 +473,13 @@ public class ControllerTest { Instant initialFailure = tester.clock().instant(); tester.notifyJobCompletion(component, app, true); tester.deployAndNotify(app, applicationPackage, false, systemTest); - assertEquals("Failure age is right at initial failure", + assertEquals("Failure age is right at initial failure", initialFailure.plus(Duration.ofMillis(2)), firstFailing(app, tester).get().at()); // Failure again -- failingSince should remain the same tester.clock().advance(Duration.ofMillis(1000)); tester.deployAndNotify(app, applicationPackage, false, systemTest); - assertEquals("Failure age is right at second consecutive failure", + assertEquals("Failure age is right at second consecutive failure", initialFailure.plus(Duration.ofMillis(2)), firstFailing(app, tester).get().at()); // Success resets failingSince @@ -364,27 +490,27 @@ public class ControllerTest { // Complete deployment tester.deployAndNotify(app, applicationPackage, true, stagingTest); tester.deployAndNotify(app, applicationPackage, true, productionCorpUsEast1); - + // Two repeated failures again. // Initial failure tester.clock().advance(Duration.ofMillis(1000)); initialFailure = tester.clock().instant(); tester.notifyJobCompletion(component, app, true); tester.deployAndNotify(app, applicationPackage, false, systemTest); - assertEquals("Failure age is right at initial failure", + assertEquals("Failure age is right at initial failure", initialFailure.plus(Duration.ofMillis(2)), firstFailing(app, tester).get().at()); // Failure again -- failingSince should remain the same tester.clock().advance(Duration.ofMillis(1000)); tester.deployAndNotify(app, applicationPackage, false, systemTest); - assertEquals("Failure age is right at second consecutive failure", + assertEquals("Failure age is right at second consecutive failure", initialFailure.plus(Duration.ofMillis(2)), firstFailing(app, tester).get().at()); } private Optional<JobStatus.JobRun> firstFailing(Application application, DeploymentTester tester) { return tester.controller().applications().get(application.id()).get().deploymentJobs().jobStatus().get(systemTest).firstFailing(); } - + @Test public void testMigratingTenantToAthenzWillModifyAthenzDomainsCorrectly() { ControllerTester tester = new ControllerTester(); @@ -542,13 +668,13 @@ public class ControllerTest { Application app = tester.createApplication(tenant, "app1", "default", 1); tester.controller().applications().lockOrThrow(app.id(), application -> { - application = application.withDeploying(Optional.of(new Change.VersionChange(Version.fromString("6.3")))); + application = application.withDeploying(Change.of(Version.fromString("6.3"))); applications.store(application); try { tester.deploy(app, ZoneId.from("prod", "us-east-3")); fail("Expected exception"); } catch (IllegalArgumentException e) { - assertEquals("Rejecting deployment of application 'tenant1.app1' to zone prod.us-east-3 as version change to 6.3 is not tested", e.getMessage()); + assertEquals("Rejecting deployment of application 'tenant1.app1' to zone prod.us-east-3 as upgrade to 6.3 is not tested", e.getMessage()); } }); } @@ -557,6 +683,7 @@ public class ControllerTest { public void testCleanupOfStaleDeploymentData() throws IOException { DeploymentTester tester = new DeploymentTester(); tester.controllerTester().zoneRegistry().setSystem(SystemName.cd); + tester.controllerTester().zoneRegistry().setZones(ZoneId.from("prod", "cd-us-central-1")); Supplier<Map<JobType, JobStatus>> statuses = () -> tester.application(ApplicationId.from("vespa", "canary", "default")).deploymentJobs().jobStatus(); @@ -764,6 +891,7 @@ public class ControllerTest { public void testDeployWithoutProjectId() { DeploymentTester tester = new DeploymentTester(); tester.controllerTester().zoneRegistry().setSystem(SystemName.cd); + tester.controllerTester().zoneRegistry().setZones(ZoneId.from("prod", "cd-us-central-1")); ApplicationPackage applicationPackage = new ApplicationPackageBuilder() .environment(Environment.prod) .region("cd-us-central-1") @@ -777,7 +905,7 @@ public class ControllerTest { // Same options as used in our integration tests DeployOptions options = new DeployOptions(Optional.empty(), Optional.empty(), false, false); - tester.controller().applications().deployApplication(app.id(), zone, applicationPackage, options); + tester.controller().applications().deployApplication(app.id(), zone, Optional.of(applicationPackage), options); assertTrue("Application deployed and activated", tester.controllerTester().configServer().activated().getOrDefault(app.id(), false)); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java index b1486c8ec00..3b574ac606b 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java @@ -7,6 +7,7 @@ import com.yahoo.config.provision.Environment; import com.yahoo.config.provision.InstanceName; import com.yahoo.config.provision.RegionName; import com.yahoo.config.provision.TenantName; +import com.yahoo.vespa.hosted.controller.api.integration.deployment.ArtifactRepository; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.slime.Slime; import com.yahoo.test.ManualClock; @@ -64,30 +65,32 @@ public final class ControllerTester { private final CuratorDb curator; private final MemoryNameService nameService; private final RotationsConfig rotationsConfig; + private final ArtifactRepositoryMock artifactRepository; private Controller controller; public ControllerTester() { this(new MemoryControllerDb(), new AthenzDbMock(), new ManualClock(), new ConfigServerClientMock(), new ZoneRegistryMock(), new GitHubMock(), new MockCuratorDb(), defaultRotationsConfig(), - new MemoryNameService()); + new MemoryNameService(), new ArtifactRepositoryMock()); } public ControllerTester(ManualClock clock) { this(new MemoryControllerDb(), new AthenzDbMock(), clock, new ConfigServerClientMock(), new ZoneRegistryMock(), new GitHubMock(), new MockCuratorDb(), defaultRotationsConfig(), - new MemoryNameService()); + new MemoryNameService(), new ArtifactRepositoryMock()); } public ControllerTester(RotationsConfig rotationsConfig) { this(new MemoryControllerDb(), new AthenzDbMock(), new ManualClock(), new ConfigServerClientMock(), - new ZoneRegistryMock(), new GitHubMock(), new MockCuratorDb(), rotationsConfig, new MemoryNameService()); + new ZoneRegistryMock(), new GitHubMock(), new MockCuratorDb(), rotationsConfig, new MemoryNameService(), + new ArtifactRepositoryMock()); } private ControllerTester(ControllerDb db, AthenzDbMock athenzDb, ManualClock clock, ConfigServerClientMock configServer, ZoneRegistryMock zoneRegistry, GitHubMock gitHub, CuratorDb curator, RotationsConfig rotationsConfig, - MemoryNameService nameService) { + MemoryNameService nameService, ArtifactRepositoryMock artifactRepository) { this.db = db; this.athenzDb = athenzDb; this.clock = clock; @@ -97,8 +100,9 @@ public final class ControllerTester { this.curator = curator; this.nameService = nameService; this.rotationsConfig = rotationsConfig; + this.artifactRepository = artifactRepository; this.controller = createController(db, curator, rotationsConfig, configServer, clock, gitHub, zoneRegistry, - athenzDb, nameService); + athenzDb, nameService, artifactRepository); } public Controller controller() { return controller; } @@ -117,10 +121,12 @@ public final class ControllerTester { public GitHubMock gitHub() { return gitHub; } + public ArtifactRepositoryMock artifactRepository() { return artifactRepository; } + /** Create a new controller instance. Useful to verify that controller state is rebuilt from persistence */ public final void createNewController() { controller = createController(db, curator, rotationsConfig, configServer, clock, gitHub, zoneRegistry, athenzDb, - nameService); + nameService, artifactRepository); } /** Creates the given tenant and application and deploys it */ @@ -209,6 +215,10 @@ public final class ControllerTester { } public void deploy(Application application, ZoneId zone, ApplicationPackage applicationPackage, boolean deployCurrentVersion) { + deploy(application, zone, Optional.of(applicationPackage), deployCurrentVersion); + } + + public void deploy(Application application, ZoneId zone, Optional<ApplicationPackage> applicationPackage, boolean deployCurrentVersion) { ScrewdriverId app1ScrewdriverId = new ScrewdriverId(String.valueOf(application.deploymentJobs().projectId().get())); GitRevision app1RevisionId = new GitRevision(new GitRepository("repo"), new GitBranch("master"), new GitCommit("commit1")); controller().applications().deployApplication(application.id(), @@ -231,7 +241,8 @@ public final class ControllerTester { private static Controller createController(ControllerDb db, CuratorDb curator, RotationsConfig rotationsConfig, ConfigServerClientMock configServerClientMock, ManualClock clock, GitHubMock gitHubClientMock, ZoneRegistryMock zoneRegistryMock, - AthenzDbMock athensDb, MemoryNameService nameService) { + AthenzDbMock athensDb, MemoryNameService nameService, + ArtifactRepository artifactRepository) { Controller controller = new Controller(db, curator, rotationsConfig, @@ -247,7 +258,8 @@ public final class ControllerTester { new MockRoutingGenerator(), new ChefMock(), clock, - new AthenzClientFactoryMock(athensDb)); + new AthenzClientFactoryMock(athensDb), + artifactRepository); controller.updateVersionStatus(VersionStatus.compute(controller)); return controller; } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ZoneRegistryMock.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ZoneRegistryMock.java index c205357c7ef..63751cfaa98 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ZoneRegistryMock.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ZoneRegistryMock.java @@ -16,6 +16,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneFilterMock; import java.net.URI; import java.time.Duration; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -34,9 +35,11 @@ public class ZoneRegistryMock extends AbstractComponent implements ZoneRegistry @Inject public ZoneRegistryMock() { - this.zones.add(ZoneId.from("prod", "corp-us-east-1")); - this.zones.add(ZoneId.from("prod", "us-east-3")); - this.zones.add(ZoneId.from("prod", "us-west-1")); + zones.add(ZoneId.from("prod", "corp-us-east-1")); + zones.add(ZoneId.from("prod", "us-east-3")); + zones.add(ZoneId.from("prod", "us-west-1")); + zones.add(ZoneId.from("prod", "us-central-1")); + zones.add(ZoneId.from("prod", "eu-west-1")); } public ZoneRegistryMock setDeploymentTimeToLive(ZoneId zone, Duration duration) { @@ -54,6 +57,10 @@ public class ZoneRegistryMock extends AbstractComponent implements ZoneRegistry return this; } + public ZoneRegistryMock setZones(ZoneId... zone) { + return setZones(Arrays.asList(zone)); + } + public ZoneRegistryMock setSystem(SystemName system) { this.system = system; return this; diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/ApplicationPackageBuilder.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/ApplicationPackageBuilder.java index 3311cffa078..9d5fcb31288 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/ApplicationPackageBuilder.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/ApplicationPackageBuilder.java @@ -71,8 +71,8 @@ public class ApplicationPackageBuilder { return this; } - public ApplicationPackageBuilder blockChange(boolean revision, boolean version, - String daySpec, String hourSpec, String zoneSpec) { + public ApplicationPackageBuilder blockChange(boolean revision, boolean version, String daySpec, String hourSpec, + String zoneSpec) { blockChange.append(" <block-change"); blockChange.append(" revision='").append(revision).append("'"); blockChange.append(" version='").append(version).append("'"); @@ -93,7 +93,8 @@ public class ApplicationPackageBuilder { } public ApplicationPackageBuilder athenzIdentity(AthenzDomain domain, AthenzService service) { - this.athenzIdentityAttributes = String.format("athenz-domain='%s' athenz-service='%s'", domain.value(), service.value()); + this.athenzIdentityAttributes = String.format("athenz-domain='%s' athenz-service='%s'", domain.value(), + service.value()); return this; } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTester.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTester.java index c9f0c6cba1d..d0dfe825558 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTester.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTester.java @@ -7,6 +7,7 @@ import com.yahoo.config.provision.Environment; import com.yahoo.test.ManualClock; import com.yahoo.vespa.hosted.controller.Application; import com.yahoo.vespa.hosted.controller.ApplicationController; +import com.yahoo.vespa.hosted.controller.ArtifactRepositoryMock; import com.yahoo.vespa.hosted.controller.ConfigServerClientMock; import com.yahoo.vespa.hosted.controller.Controller; import com.yahoo.vespa.hosted.controller.ControllerTester; @@ -16,8 +17,9 @@ import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; import com.yahoo.vespa.hosted.controller.application.Change; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobType; -import com.yahoo.vespa.hosted.controller.maintenance.ReadyJobsTrigger; +import com.yahoo.vespa.hosted.controller.application.SourceRevision; import com.yahoo.vespa.hosted.controller.maintenance.JobControl; +import com.yahoo.vespa.hosted.controller.maintenance.ReadyJobsTrigger; import com.yahoo.vespa.hosted.controller.maintenance.Upgrader; import com.yahoo.vespa.hosted.controller.versions.VersionStatus; @@ -43,6 +45,7 @@ public class DeploymentTester { // Set a long interval so that maintainers never do scheduled runs during tests private static final Duration maintenanceInterval = Duration.ofDays(1); + private static final int defaultBuildNumber = 42; private final ControllerTester tester; private final Upgrader upgrader; @@ -81,6 +84,8 @@ public class DeploymentTester { public ConfigServerClientMock configServer() { return tester.configServer(); } + public ArtifactRepositoryMock artifactRepository() { return tester.artifactRepository(); } + public Application application(String name) { return application(ApplicationId.from("tenant1", name, "default")); } @@ -89,12 +94,6 @@ public class DeploymentTester { return controller().applications().require(application); } - public Optional<Change.VersionChange> versionChange(ApplicationId application) { - return application(application).deploying() - .filter(c -> c instanceof Change.VersionChange) - .map(Change.VersionChange.class::cast); - } - public void updateVersionStatus() { controller().updateVersionStatus(VersionStatus.compute(controller(), tester.controller().systemVersion())); } @@ -149,21 +148,23 @@ public class DeploymentTester { /** Deploy application completely using the given application package */ public void deployCompletely(Application application, ApplicationPackage applicationPackage) { notifyJobCompletion(JobType.component, application, true); - assertTrue(applications().require(application.id()).deploying().isPresent()); + assertTrue(applications().require(application.id()).change().isPresent()); completeDeployment(application, applicationPackage, Optional.empty(), true); } public static DeploymentJobs.JobReport jobReport(Application application, JobType jobType, boolean success) { - return jobReport(application, jobType, Optional.ofNullable(success ? null : unknown)); + return jobReport(application, jobType, Optional.ofNullable(success ? null : unknown), Optional.empty(), defaultBuildNumber); } - public static DeploymentJobs.JobReport jobReport(Application application, JobType jobType, Optional<DeploymentJobs.JobError> jobError) { + public static DeploymentJobs.JobReport jobReport(Application application, JobType jobType, + Optional<DeploymentJobs.JobError> jobError, + Optional<SourceRevision> sourceRevision, long buildNumber) { return new DeploymentJobs.JobReport( application.id(), jobType, application.deploymentJobs().projectId().get(), - 42, - Optional.empty(), + buildNumber, + sourceRevision, jobError ); } @@ -171,7 +172,7 @@ public class DeploymentTester { /** Deploy application using the given application package, but expecting to stop after test phases */ public void deployTestOnly(Application application, ApplicationPackage applicationPackage) { notifyJobCompletion(JobType.component, application, true); - assertTrue(applications().require(application.id()).deploying().isPresent()); + assertTrue(applications().require(application.id()).change().isPresent()); completeDeployment(application, applicationPackage, Optional.empty(), false); } @@ -189,13 +190,13 @@ public class DeploymentTester { } } if (failOnJob.isPresent()) { - assertTrue(applications().require(application.id()).deploying().isPresent()); + assertTrue(applications().require(application.id()).change().isPresent()); assertTrue(applications().require(application.id()).deploymentJobs().hasFailures()); } else if (includingProductionZones) { - assertFalse(applications().require(application.id()).deploying().isPresent()); + assertFalse(applications().require(application.id()).change().isPresent()); } else { - assertTrue(applications().require(application.id()).deploying().isPresent()); + assertTrue(applications().require(application.id()).change().isPresent()); } } @@ -204,8 +205,13 @@ public class DeploymentTester { } public void notifyJobCompletion(JobType jobType, Application application, Optional<DeploymentJobs.JobError> jobError) { + notifyJobCompletion(jobType, application, jobError, Optional.empty(), defaultBuildNumber); + } + + public void notifyJobCompletion(JobType jobType, Application application, Optional<DeploymentJobs.JobError> jobError, + Optional<SourceRevision> source, long buildNumber) { clock().advance(Duration.ofMillis(1)); - applications().notifyJobCompletion(jobReport(application, jobType, jobError)); + applications().notifyJobCompletion(jobReport(application, jobType, jobError, source, buildNumber)); } public void completeUpgrade(Application application, Version version, String upgradePolicy) { @@ -213,8 +219,8 @@ public class DeploymentTester { } public void completeUpgrade(Application application, Version version, ApplicationPackage applicationPackage) { - assertTrue(application + " has a deployment", applications().require(application.id()).deploying().isPresent()); - assertEquals(new Change.VersionChange(version), applications().require(application.id()).deploying().get()); + assertTrue(application + " has a deployment", applications().require(application.id()).change().isPresent()); + assertEquals(Change.of(version), applications().require(application.id()).change()); completeDeployment(application, applicationPackage, Optional.empty(), true); } @@ -227,17 +233,28 @@ public class DeploymentTester { } private void completeUpgradeWithError(Application application, Version version, ApplicationPackage applicationPackage, Optional<JobType> failOnJob) { - assertTrue(applications().require(application.id()).deploying().isPresent()); - assertEquals(new Change.VersionChange(version), applications().require(application.id()).deploying().get()); + assertTrue(applications().require(application.id()).change().isPresent()); + assertEquals(Change.of(version), applications().require(application.id()).change()); completeDeployment(application, applicationPackage, failOnJob, true); } public void deploy(JobType job, Application application, ApplicationPackage applicationPackage) { - deploy(job, application, applicationPackage, false); + deploy(job, application, Optional.of(applicationPackage), false); } - public void deploy(JobType job, Application application, ApplicationPackage applicationPackage, boolean deployCurrentVersion) { - job.zone(controller().system()).ifPresent(zone -> tester.deploy(application, zone, applicationPackage, deployCurrentVersion)); + public void deploy(JobType job, Application application, ApplicationPackage applicationPackage, + boolean deployCurrentVersion) { + deploy(job, application, Optional.of(applicationPackage), deployCurrentVersion); + } + + public void deploy(JobType job, Application application, Optional<ApplicationPackage> applicationPackage, + boolean deployCurrentVersion) { + job.zone(controller().system()).ifPresent(zone -> tester.deploy(application, zone, applicationPackage, + deployCurrentVersion)); + } + + public void deployAndNotify(Application application, boolean success, JobType... job) { + deployAndNotify(application, Optional.empty(), success, true, job); } public void deployAndNotify(Application application, String upgradePolicy, boolean success, JobType... jobs) { @@ -251,10 +268,15 @@ public class DeploymentTester { public void deployAndNotify(Application application, ApplicationPackage applicationPackage, boolean success, boolean expectOnlyTheseJobs, JobType... jobs) { + deployAndNotify(application, Optional.of(applicationPackage), success, expectOnlyTheseJobs, jobs); + } + + public void deployAndNotify(Application application, Optional<ApplicationPackage> applicationPackage, + boolean success, boolean expectOnlyTheseJobs, JobType... jobs) { consumeJobs(application, expectOnlyTheseJobs, jobs); for (JobType job : jobs) { if (success) { - deploy(job, application, applicationPackage); + deploy(job, application, applicationPackage, false); } notifyJobCompletion(job, application, success); } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java index b71a9090c79..5c61e43f9cf 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java @@ -13,13 +13,12 @@ import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; import com.yahoo.vespa.hosted.controller.application.Change; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobType; -import com.yahoo.vespa.hosted.controller.maintenance.ReadyJobsTrigger; import com.yahoo.vespa.hosted.controller.maintenance.JobControl; +import com.yahoo.vespa.hosted.controller.maintenance.ReadyJobsTrigger; import org.junit.Test; import java.time.Duration; import java.time.Instant; -import java.util.Optional; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -221,13 +220,13 @@ public class DeploymentTriggerTest { tester.deploy(DeploymentJobs.JobType.productionUsWest1, app, applicationPackage); tester.notifyJobCompletion(DeploymentJobs.JobType.productionUsWest1, app, true); assertTrue("Change is present as not all jobs are complete", - tester.applications().require(app.id()).deploying().isPresent()); + tester.applications().require(app.id()).change().isPresent()); // All jobs complete tester.deploy(DeploymentJobs.JobType.productionUsEast3, app, applicationPackage); tester.notifyJobCompletion(DeploymentJobs.JobType.productionUsEast3, app, true); assertFalse("Change has been deployed", - tester.applications().require(app.id()).deploying().isPresent()); + tester.applications().require(app.id()).change().isPresent()); } @Test @@ -247,7 +246,7 @@ public class DeploymentTriggerTest { .region("corp-us-east-1") .region("us-central-1") .region("us-west-1") - .region("ap-northeast-1") + .region("eu-west-1") .build(); // Component job finishes @@ -260,7 +259,7 @@ public class DeploymentTriggerTest { tester.deployAndNotify(application, newApplicationPackage, true, JobType.productionCorpUsEast1); tester.deployAndNotify(application, newApplicationPackage, true, JobType.productionUsCentral1); tester.deployAndNotify(application, newApplicationPackage, true, JobType.productionUsWest1); - tester.deployAndNotify(application, newApplicationPackage, true, JobType.productionApNortheast1); + tester.deployAndNotify(application, newApplicationPackage, true, JobType.productionEuWest1); assertTrue("All jobs consumed", buildSystem.jobs().isEmpty()); } @@ -277,7 +276,7 @@ public class DeploymentTriggerTest { ApplicationPackageBuilder applicationPackageBuilder = new ApplicationPackageBuilder() .upgradePolicy("canary") - // Block revision changes on tuesday in hours 18 and 19 + // Block application version changes on tuesday in hours 18 and 19 .blockChange(true, false, "tue", "18-19", "UTC") .region("us-west-1") .region("us-central-1") @@ -321,7 +320,7 @@ public class DeploymentTriggerTest { new JobControl(tester.controllerTester().curator())); LockedApplication app = (LockedApplication)tester.createAndDeploy("default0", 3, "default"); // Store that we are upgrading but don't start the system-tests job - tester.controller().applications().store(app.withDeploying(Optional.of(new Change.VersionChange(Version.fromString("6.2"))))); + tester.controller().applications().store(app.withDeploying(Change.of(Version.fromString("6.2")))); assertEquals(0, tester.buildSystem().jobs().size()); readyJobsTrigger.run(); assertEquals(1, tester.buildSystem().jobs().size()); @@ -350,7 +349,7 @@ public class DeploymentTriggerTest { // Extra notification for last job tester.notifyJobCompletion(JobType.productionCorpUsEast1, application, true); assertFalse("Change has been deployed", - tester.applications().require(application.id()).deploying().isPresent()); + tester.applications().require(application.id()).change().isPresent()); assertTrue("All jobs consumed", buildSystem.jobs().isEmpty()); } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmerTest.java index b0954044c22..c775dd3fd7c 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationOwnershipConfirmerTest.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.maintenance; import com.yahoo.config.provision.ApplicationId; diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentExpirerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentExpirerTest.java index d48f7b84ee6..513e5520d85 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentExpirerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentExpirerTest.java @@ -3,8 +3,8 @@ package com.yahoo.vespa.hosted.controller.maintenance; import com.yahoo.config.provision.Environment; import com.yahoo.config.provision.RegionName; -import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.Application; +import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; import com.yahoo.vespa.hosted.controller.application.Deployment; import com.yahoo.vespa.hosted.controller.deployment.ApplicationPackageBuilder; @@ -13,7 +13,6 @@ import com.yahoo.vespa.hosted.controller.persistence.MockCuratorDb; import org.junit.Before; import org.junit.Test; -import java.io.IOException; import java.time.Duration; import java.util.List; import java.util.stream.Collectors; @@ -33,7 +32,7 @@ public class DeploymentExpirerTest { } @Test - public void testDeploymentExpiry() throws IOException, InterruptedException { + public void testDeploymentExpiry() { tester.controllerTester().zoneRegistry().setDeploymentTimeToLive( ZoneId.from(Environment.dev, RegionName.from("us-east-1")), Duration.ofDays(14) diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/FailureRedeployerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/FailureRedeployerTest.java index 62e6d379c60..11e85c6be9f 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/FailureRedeployerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/FailureRedeployerTest.java @@ -7,6 +7,7 @@ import com.yahoo.config.provision.SystemName; import com.yahoo.slime.Slime; import com.yahoo.vespa.config.SlimeUtils; import com.yahoo.vespa.hosted.controller.Application; +import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs; import com.yahoo.vespa.hosted.controller.deployment.ApplicationPackageBuilder; @@ -57,7 +58,7 @@ public class FailureRedeployerTest { tester.clock().advance(Duration.ofSeconds(1)); // Advance time so that we can detect jobs in progress tester.deployAndNotify(app, applicationPackage, false, DeploymentJobs.JobType.productionUsEast3); assertEquals("Production job is retried", 1, tester.buildSystem().jobs().size()); - assertEquals("Application has pending upgrade to " + version, version, tester.versionChange(app.id()).get().version()); + assertEquals("Application has pending upgrade to " + version, version, tester.application(app.id()).change().platform().get()); // Another version is released, which cancels any pending upgrades to lower versions version = Version.fromString("5.2"); @@ -65,7 +66,7 @@ public class FailureRedeployerTest { tester.deployAndNotify(app, applicationPackage, true, DeploymentJobs.JobType.productionUsEast3); // Finish previous production job. tester.upgrader().maintain(); assertEquals("Application starts upgrading to new version", 1, tester.buildSystem().jobs().size()); - assertEquals("Application has pending upgrade to " + version, version, tester.versionChange(app.id()).get().version()); + assertEquals("Application has pending upgrade to " + version, version, tester.application(app.id()).change().platform().get()); // Failure redeployer does not retry failing job for prod.us-east-3 as there's an ongoing deployment tester.clock().advance(Duration.ofMinutes(1)); @@ -149,7 +150,7 @@ public class FailureRedeployerTest { tester.updateVersionStatus(version); assertEquals(version, tester.controller().versionStatus().systemVersion().get().versionNumber()); tester.upgrader().maintain(); - assertEquals("Application has pending upgrade to " + version, version, tester.versionChange(app.id()).get().version()); + assertEquals("Application has pending upgrade to " + version, version, tester.application(app.id()).change().platform().get()); // system-test fails and exhausts all immediate retries tester.deployAndNotify(app, applicationPackage, false, DeploymentJobs.JobType.systemTest); @@ -163,7 +164,7 @@ public class FailureRedeployerTest { tester.updateVersionStatus(version); assertEquals(version, tester.controller().versionStatus().systemVersion().get().versionNumber()); tester.upgrader().maintain(); - assertEquals("Application has pending upgrade to " + version, version, tester.versionChange(app.id()).get().version()); + assertEquals("Application has pending upgrade to " + version, version, tester.application(app.id()).change().platform().get()); // Consume system-test job for 5.2 tester.buildSystem().takeJobsToRun(); @@ -178,6 +179,7 @@ public class FailureRedeployerTest { public void retryIgnoresStaleJobData() throws Exception { DeploymentTester tester = new DeploymentTester(); tester.controllerTester().zoneRegistry().setSystem(SystemName.cd); + tester.controllerTester().zoneRegistry().setZones(ZoneId.from("prod", "cd-us-central-1")); // Current system version, matches version in test data Version version = Version.fromString("6.141.117"); @@ -223,12 +225,12 @@ public class FailureRedeployerTest { // Deployment notifies completeness but has not actually made a deployment tester.notifyJobCompletion(DeploymentJobs.JobType.productionCdUsCentral1, application, true); - assertTrue("Change not really deployed", tester.application(application.id()).deploying().isPresent()); + assertTrue("Change not really deployed", tester.application(application.id()).change().isPresent()); // Deployment actually deploys and notifies completeness tester.deploy(DeploymentJobs.JobType.productionCdUsCentral1, application, applicationPackage); tester.notifyJobCompletion(DeploymentJobs.JobType.productionCdUsCentral1, application, true); - assertFalse("Change not really deployed", tester.application(application.id()).deploying().isPresent()); + assertFalse("Change not really deployed", tester.application(application.id()).change().isPresent()); } @Test diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/OutstandingChangeDeployerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/OutstandingChangeDeployerTest.java index 13636122cfd..3d34e78c759 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/OutstandingChangeDeployerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/OutstandingChangeDeployerTest.java @@ -20,28 +20,28 @@ import static org.junit.Assert.assertTrue; * @author bratseth */ public class OutstandingChangeDeployerTest { - + @Test public void testChangeDeployer() { DeploymentTester tester = new DeploymentTester(); tester.configServer().setDefaultVersion(new Version(6, 1)); - OutstandingChangeDeployer deployer = new OutstandingChangeDeployer(tester.controller(), Duration.ofMinutes(10), + OutstandingChangeDeployer deployer = new OutstandingChangeDeployer(tester.controller(), Duration.ofMinutes(10), new JobControl(new MockCuratorDb())); tester.createAndDeploy("app1", 11, "default"); tester.createAndDeploy("app2", 22, "default"); Version version = new Version(6, 2); - tester.deploymentTrigger().triggerChange(tester.application("app1").id(), new Change.VersionChange(version)); - - assertEquals(new Change.VersionChange(version), tester.application("app1").deploying().get()); + tester.deploymentTrigger().triggerChange(tester.application("app1").id(), Change.of(version)); + + assertEquals(Change.of(version), tester.application("app1").change()); assertFalse(tester.application("app1").hasOutstandingChange()); tester.notifyJobCompletion(DeploymentJobs.JobType.component, tester.application("app1"), true); assertTrue(tester.application("app1").hasOutstandingChange()); assertEquals(1, tester.buildSystem().jobs().size()); - + deployer.maintain(); assertEquals("No effect as job is in progress", 1, tester.buildSystem().jobs().size()); - + tester.deployCompletely("app1"); assertEquals("Upgrade done", 0, tester.buildSystem().jobs().size()); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/UpgraderTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/UpgraderTest.java index fd13b99f25e..ccc029a9654 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/UpgraderTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/UpgraderTest.java @@ -9,7 +9,6 @@ import com.yahoo.vespa.hosted.controller.Application; import com.yahoo.vespa.hosted.controller.ControllerTester; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; -import com.yahoo.vespa.hosted.controller.application.Change; import com.yahoo.vespa.hosted.controller.application.Deployment; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs; import com.yahoo.vespa.hosted.controller.deployment.ApplicationPackageBuilder; @@ -174,11 +173,11 @@ public class UpgraderTest { assertEquals(VespaVersion.Confidence.normal, tester.controller().versionStatus().systemVersion().get().confidence()); tester.upgrader().maintain(); assertEquals("Upgrade of defaults are scheduled", 5, tester.buildSystem().jobs().size()); - assertEquals(version54, ((Change.VersionChange)tester.application(default0.id()).deploying().get()).version()); - assertEquals(version54, ((Change.VersionChange)tester.application(default1.id()).deploying().get()).version()); - assertEquals(version54, ((Change.VersionChange)tester.application(default2.id()).deploying().get()).version()); - assertEquals(version54, ((Change.VersionChange)tester.application(default3.id()).deploying().get()).version()); - assertEquals(version54, ((Change.VersionChange)tester.application(default4.id()).deploying().get()).version()); + assertEquals(version54, tester.application(default0.id()).change().platform().get()); + assertEquals(version54, tester.application(default1.id()).change().platform().get()); + assertEquals(version54, tester.application(default2.id()).change().platform().get()); + assertEquals(version54, tester.application(default3.id()).change().platform().get()); + assertEquals(version54, tester.application(default4.id()).change().platform().get()); tester.completeUpgrade(default0, version54, "default"); // State: Default applications started upgrading to 5.4 (and one completed) Version version55 = Version.fromString("5.5"); @@ -190,11 +189,11 @@ public class UpgraderTest { assertEquals(VespaVersion.Confidence.normal, tester.controller().versionStatus().systemVersion().get().confidence()); tester.upgrader().maintain(); assertEquals("Upgrade of defaults are scheduled", 5, tester.buildSystem().jobs().size()); - assertEquals(version55, ((Change.VersionChange)tester.application(default0.id()).deploying().get()).version()); - assertEquals(version54, ((Change.VersionChange)tester.application(default1.id()).deploying().get()).version()); - assertEquals(version54, ((Change.VersionChange)tester.application(default2.id()).deploying().get()).version()); - assertEquals(version54, ((Change.VersionChange)tester.application(default3.id()).deploying().get()).version()); - assertEquals(version54, ((Change.VersionChange)tester.application(default4.id()).deploying().get()).version()); + assertEquals(version55, tester.application(default0.id()).change().platform().get()); + assertEquals(version54, tester.application(default1.id()).change().platform().get()); + assertEquals(version54, tester.application(default2.id()).change().platform().get()); + assertEquals(version54, tester.application(default3.id()).change().platform().get()); + assertEquals(version54, tester.application(default4.id()).change().platform().get()); tester.completeUpgrade(default1, version54, "default"); tester.completeUpgrade(default2, version54, "default"); tester.completeUpgradeWithError(default3, version54, "default", DeploymentJobs.JobType.stagingTest); @@ -216,7 +215,7 @@ public class UpgraderTest { assertEquals("Upgrade of defaults are scheduled on 5.4 instead, since 5.5 broken: " + "This is default3 since it failed upgrade on both 5.4 and 5.5", 1, tester.buildSystem().jobs().size()); - assertEquals("5.4", ((Change.VersionChange)tester.application(default3.id()).deploying().get()).version().toString()); + assertEquals("5.4", tester.application(default3.id()).change().platform().get().toString()); } @Test @@ -320,7 +319,7 @@ public class UpgraderTest { tester.notifyJobCompletion(DeploymentJobs.JobType.stagingTest, app, false); assertTrue("Retries exhausted", tester.buildSystem().jobs().isEmpty()); assertTrue("Failure is recorded", tester.application(app.id()).deploymentJobs().hasFailures()); - assertTrue("Application has pending change", tester.application(app.id()).deploying().isPresent()); + assertTrue("Application has pending change", tester.application(app.id()).change().isPresent()); // New version is released version = Version.fromString("5.2"); @@ -379,7 +378,7 @@ public class UpgraderTest { tester.upgrader().maintain(); // 5th app passes system-test, but does not trigger next job as upgrade is cancelled - assertFalse("No change present", tester.applications().require(default4.id()).deploying().isPresent()); + assertFalse("No change present", tester.applications().require(default4.id()).change().isPresent()); tester.notifyJobCompletion(DeploymentJobs.JobType.systemTest, default4, true); assertTrue("All jobs consumed", tester.buildSystem().jobs().isEmpty()); } @@ -476,7 +475,7 @@ public class UpgraderTest { assertEquals(v2, tester.application("default0").deployments().get(ZoneId.from("prod.us-west-1")).version()); assertEquals("Last zone is upgraded to v1", v1, tester.application("default0").deployments().get(ZoneId.from("prod.us-east-3")).version()); - assertFalse(tester.application("default0").deploying().isPresent()); + assertFalse(tester.application("default0").change().isPresent()); } @Test @@ -747,7 +746,7 @@ public class UpgraderTest { // 5th app never reports back and has a dead job, but no ongoing change Application deadLocked = tester.applications().require(default4.id()); assertTrue("Jobs in progress", deadLocked.deploymentJobs().isRunning(tester.controller().applications().deploymentTrigger().jobTimeoutLimit())); - assertFalse("No change present", deadLocked.deploying().isPresent()); + assertFalse("No change present", deadLocked.change().isPresent()); // 4 out of 5 applications are repaired and confidence is restored ApplicationPackage defaultApplicationPackageV2 = new ApplicationPackageBuilder() diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java index d7389ca94cd..ffb9ee57351 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java @@ -6,14 +6,14 @@ import com.yahoo.config.application.api.DeploymentSpec; import com.yahoo.config.application.api.ValidationOverrides; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ClusterSpec; -import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.slime.Slime; import com.yahoo.vespa.config.SlimeUtils; import com.yahoo.vespa.hosted.controller.Application; import com.yahoo.vespa.hosted.controller.ControllerTester; import com.yahoo.vespa.hosted.controller.api.integration.MetricsService; import com.yahoo.vespa.hosted.controller.api.integration.organization.IssueId; -import com.yahoo.vespa.hosted.controller.application.ApplicationRevision; +import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; +import com.yahoo.vespa.hosted.controller.application.ApplicationVersion; import com.yahoo.vespa.hosted.controller.application.Change; import com.yahoo.vespa.hosted.controller.application.ClusterInfo; import com.yahoo.vespa.hosted.controller.application.ClusterUtilization; @@ -26,7 +26,6 @@ import com.yahoo.vespa.hosted.controller.application.SourceRevision; import com.yahoo.vespa.hosted.controller.rotation.RotationId; import org.junit.Test; -import java.io.IOException; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.ArrayList; @@ -60,20 +59,21 @@ public class ApplicationSerializerTest { "</validation-overrides>"); List<Deployment> deployments = new ArrayList<>(); - ApplicationRevision revision1 = ApplicationRevision.from("appHash1"); - ApplicationRevision revision2 = ApplicationRevision.from("appHash2", new SourceRevision("repo1", "branch1", "commit1")); - deployments.add(new Deployment(zone1, revision1, Version.fromString("1.2.3"), Instant.ofEpochMilli(3))); // One deployment without cluster info and utils - deployments.add(new Deployment(zone2, revision2, Version.fromString("1.2.3"), Instant.ofEpochMilli(5), + ApplicationVersion applicationVersion1 = ApplicationVersion.from("appHash1"); + ApplicationVersion applicationVersion2 = ApplicationVersion + .from("appHash2", new SourceRevision("repo1", "branch1", "commit1")); + deployments.add(new Deployment(zone1, applicationVersion1, Version.fromString("1.2.3"), Instant.ofEpochMilli(3))); // One deployment without cluster info and utils + deployments.add(new Deployment(zone2, applicationVersion2, Version.fromString("1.2.3"), Instant.ofEpochMilli(5), createClusterUtils(3, 0.2), createClusterInfo(3, 4),new DeploymentMetrics(2,3,4,5,6))); Optional<Long> projectId = Optional.of(123L); List<JobStatus> statusList = new ArrayList<>(); statusList.add(JobStatus.initial(DeploymentJobs.JobType.systemTest) - .withTriggering(Version.fromString("5.6.7"), Optional.empty(), true, "Test", Instant.ofEpochMilli(7)) + .withTriggering(Version.fromString("5.6.7"), ApplicationVersion.unknown, true, "Test", Instant.ofEpochMilli(7)) .withCompletion(30, Optional.empty(), Instant.ofEpochMilli(8), tester.controller())); statusList.add(JobStatus.initial(DeploymentJobs.JobType.stagingTest) - .withTriggering(Version.fromString("5.6.6"), Optional.empty(), true, "Test 2", Instant.ofEpochMilli(5)) + .withTriggering(Version.fromString("5.6.6"), ApplicationVersion.unknown, true, "Test 2", Instant.ofEpochMilli(5)) .withCompletion(11, Optional.of(JobError.unknown), Instant.ofEpochMilli(6), tester.controller())); DeploymentJobs deploymentJobs = new DeploymentJobs(projectId, statusList, Optional.empty()); @@ -82,7 +82,7 @@ public class ApplicationSerializerTest { deploymentSpec, validationOverrides, deployments, deploymentJobs, - Optional.of(new Change.VersionChange(Version.fromString("6.7"))), + Change.of(Version.fromString("6.7")), true, Optional.of(IssueId.from("1234")), new MetricsService.ApplicationMetrics(0.5, 0.9), @@ -96,8 +96,8 @@ public class ApplicationSerializerTest { assertEquals(original.validationOverrides().xmlForm(), serialized.validationOverrides().xmlForm()); assertEquals(2, serialized.deployments().size()); - assertEquals(original.deployments().get(zone1).revision(), serialized.deployments().get(zone1).revision()); - assertEquals(original.deployments().get(zone2).revision(), serialized.deployments().get(zone2).revision()); + assertEquals(original.deployments().get(zone1).applicationVersion(), serialized.deployments().get(zone1).applicationVersion()); + assertEquals(original.deployments().get(zone2).applicationVersion(), serialized.deployments().get(zone2).applicationVersion()); assertEquals(original.deployments().get(zone1).version(), serialized.deployments().get(zone1).version()); assertEquals(original.deployments().get(zone2).version(), serialized.deployments().get(zone2).version()); assertEquals(original.deployments().get(zone1).at(), serialized.deployments().get(zone1).at()); @@ -114,7 +114,7 @@ public class ApplicationSerializerTest { assertEquals(original.ownershipIssueId(), serialized.ownershipIssueId()); - assertEquals(original.deploying(), serialized.deploying()); + assertEquals(original.change(), serialized.change()); assertEquals(original.rotation().get().id(), serialized.rotation().get().id()); // Test cluster utilization @@ -145,22 +145,22 @@ public class ApplicationSerializerTest { assertEquals(6, serialized.deployments().get(zone2).metrics().writeLatencyMillis(), Double.MIN_VALUE); { // test more deployment serialization cases - Application original2 = writable(original).withDeploying(Optional.of(Change.ApplicationChange.of(ApplicationRevision.from("hash1")))); + Application original2 = writable(original).withDeploying(Change.of(ApplicationVersion.from("hash1"))); Application serialized2 = applicationSerializer.fromSlime(applicationSerializer.toSlime(original2)); - assertEquals(original2.deploying(), serialized2.deploying()); - assertEquals(((Change.ApplicationChange)serialized2.deploying().get()).revision().get().source(), - ((Change.ApplicationChange)original2.deploying().get()).revision().get().source()); + assertEquals(original2.change(), serialized2.change()); + assertEquals(serialized2.change().application().get().source(), + original2.change().application().get().source()); - Application original3 = writable(original).withDeploying(Optional.of(Change.ApplicationChange.of(ApplicationRevision.from("hash1", - new SourceRevision("a", "b", "c"))))); + Application original3 = writable(original).withDeploying(Change.of(ApplicationVersion.from("hash1", + new SourceRevision("a", "b", "c")))); Application serialized3 = applicationSerializer.fromSlime(applicationSerializer.toSlime(original3)); - assertEquals(original3.deploying(), serialized2.deploying()); - assertEquals(((Change.ApplicationChange)serialized3.deploying().get()).revision().get().source(), - ((Change.ApplicationChange)original3.deploying().get()).revision().get().source()); + assertEquals(original3.change(), serialized2.change()); + assertEquals(serialized3.change().application().get().source(), + original3.change().application().get().source()); - Application original4 = writable(original).withDeploying(Optional.empty()); + Application original4 = writable(original).withDeploying(Change.empty()); Application serialized4 = applicationSerializer.fromSlime(applicationSerializer.toSlime(original4)); - assertEquals(original4.deploying(), serialized4.deploying()); + assertEquals(original4.change(), serialized4.change()); } } @@ -195,7 +195,7 @@ public class ApplicationSerializerTest { } @Test - public void testLegacySerialization() throws IOException { + public void testLegacySerialization() { Application applicationWithSuccessfulJob = applicationSerializer.fromSlime(applicationSlime(false)); assertFalse("No job error for successful job", applicationWithSuccessfulJob.deploymentJobs().jobStatus().get(DeploymentJobs.JobType.systemTest).jobError().isPresent()); @@ -208,10 +208,10 @@ public class ApplicationSerializerTest { Application application = applicationSerializer.fromSlime(applicationSlime(false)); assertFalse(application.deploymentJobs().jobStatus().get(DeploymentJobs.JobType.systemTest).lastCompleted().get().upgrade()); } - + @Test public void testCompleteApplicationDeserialization() { - Application application = applicationSerializer.fromSlime(SlimeUtils.jsonToSlime(longApplicationJson.getBytes(StandardCharsets.UTF_8))); + applicationSerializer.fromSlime(SlimeUtils.jsonToSlime(longApplicationJson.getBytes(StandardCharsets.UTF_8))); // ok if no error } @@ -251,6 +251,6 @@ public class ApplicationSerializerTest { " }\n" + "}\n"; } - + private final String longApplicationJson = "{\"id\":\"tripod:service-aggregation-vespa:default\",\"deploymentSpecField\":\"<deployment version='1.0'>\\n <test />\\n <!--<staging />-->\\n <prod global-service-id=\\\"tripod\\\">\\n <region active=\\\"true\\\">us-east-3</region>\\n <region active=\\\"true\\\">us-west-1</region>\\n </prod>\\n</deployment>\\n\",\"validationOverrides\":\"<validation-overrides>\\n <allow until=\\\"2016-04-28\\\" comment=\\\"Renaming content cluster\\\">content-cluster-removal</allow>\\n <allow until=\\\"2016-08-22\\\" comment=\\\"Migrating us-east-3 to C-2E\\\">cluster-size-reduction</allow>\\n <allow until=\\\"2017-06-30\\\" comment=\\\"Test Vespa upgrade tests\\\">force-automatic-tenant-upgrade-test</allow>\\n</validation-overrides>\\n\",\"deployments\":[{\"zone\":{\"environment\":\"prod\",\"region\":\"us-west-1\"},\"version\":\"6.173.62\",\"deployTime\":1510837817704,\"applicationPackageRevision\":{\"applicationPackageHash\":\"9db423e1021d7b452d37ec6372bc757d9c1bda87\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"49cd7bbb1ed9f4b922083cb042590b0885ffe22b\"}},\"clusterInfo\":{\"tripod\":{\"flavor\":\"d-3-16-100\",\"cost\":9,\"flavorCpu\":0,\"flavorMem\":0,\"flavorDisk\":0,\"clusterType\":\"container\",\"hostnames\":[\"oxy-oxygen-2001-4998-c-2942--10d1.gq1.yahoo.com\",\"oxy-oxygen-2001-4998-c-2942--10e2.gq1.yahoo.com\"]},\"tripodaggregation\":{\"flavor\":\"d-12-64-400\",\"cost\":38,\"flavorCpu\":0,\"flavorMem\":0,\"flavorDisk\":0,\"clusterType\":\"content\",\"hostnames\":[\"oxy-oxygen-2001-4998-c-2941--106a.gq1.yahoo.com\",\"zt74700-v6-23.ostk.bm2.prod.gq1.yahoo.com\",\"zt74714-v6-28.ostk.bm2.prod.gq1.yahoo.com\",\"zt74730-v6-13.ostk.bm2.prod.gq1.yahoo.com\",\"zt74717-v6-7.ostk.bm2.prod.gq1.yahoo.com\",\"2080260-v6-12.ostk.bm2.prod.gq1.yahoo.com\",\"zt74719-v6-23.ostk.bm2.prod.gq1.yahoo.com\",\"zt74722-v6-26.ostk.bm2.prod.gq1.yahoo.com\",\"zt74704-v6-9.ostk.bm2.prod.gq1.yahoo.com\",\"oxy-oxygen-2001-4998-c-2942--107d.gq1.yahoo.com\"]},\"tripodaggregationstream\":{\"flavor\":\"d-12-64-400\",\"cost\":38,\"flavorCpu\":0,\"flavorMem\":0,\"flavorDisk\":0,\"clusterType\":\"content\",\"hostnames\":[\"zt74727-v6-21.ostk.bm2.prod.gq1.yahoo.com\",\"zt74773-v6-8.ostk.bm2.prod.gq1.yahoo.com\",\"zt74699-v6-25.ostk.bm2.prod.gq1.yahoo.com\",\"zt74766-v6-27.ostk.bm2.prod.gq1.yahoo.com\"]}},\"clusterUtils\":{\"tripod\":{\"cpu\":0.1720353499228221,\"mem\":0.4986146831512451,\"disk\":0.0617671330041831,\"diskbusy\":0},\"tripodaggregation\":{\"cpu\":0.07505730001866318,\"mem\":0.7936344432830811,\"disk\":0.2260549694485994,\"diskbusy\":0},\"tripodaggregationstream\":{\"cpu\":0.01712671480989384,\"mem\":0.0225852754983035,\"disk\":0.006084436856721915,\"diskbusy\":0}},\"metrics\":{\"queriesPerSecond\":1.25,\"writesPerSecond\":43.83199977874756,\"documentCount\":525880277.9999999,\"queryLatencyMillis\":5.607503938674927,\"writeLatencyMillis\":20.57866265104621}},{\"zone\":{\"environment\":\"test\",\"region\":\"us-east-1\"},\"version\":\"6.173.62\",\"deployTime\":1511256872316,\"applicationPackageRevision\":{\"applicationPackageHash\":\"ec548fa61cbfab7a270a51d46b1263ec1be5d9a8\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"234f3e4e77049d0b9538c9e1b356d17eb1dedb6a\"}},\"clusterInfo\":{},\"clusterUtils\":{},\"metrics\":{\"queriesPerSecond\":0,\"writesPerSecond\":0,\"documentCount\":0,\"queryLatencyMillis\":0,\"writeLatencyMillis\":0}},{\"zone\":{\"environment\":\"dev\",\"region\":\"us-east-1\"},\"version\":\"6.173.62\",\"deployTime\":1510597489464,\"applicationPackageRevision\":{\"applicationPackageHash\":\"59b883f263c2a3c23dfab249730097d7e0e1ed32\"},\"clusterInfo\":{\"tripod\":{\"flavor\":\"d-2-8-50\",\"cost\":5,\"flavorCpu\":0,\"flavorMem\":0,\"flavorDisk\":0,\"clusterType\":\"container\",\"hostnames\":[\"zt40807-v6-29.ostk.bm2.prod.bf1.yahoo.com\"]},\"tripodaggregation\":{\"flavor\":\"d-2-8-50\",\"cost\":5,\"flavorCpu\":0,\"flavorMem\":0,\"flavorDisk\":0,\"clusterType\":\"content\",\"hostnames\":[\"zt40807-v6-24.ostk.bm2.prod.bf1.yahoo.com\"]},\"tripodaggregationstream\":{\"flavor\":\"d-2-8-50\",\"cost\":5,\"flavorCpu\":0,\"flavorMem\":0,\"flavorDisk\":0,\"clusterType\":\"content\",\"hostnames\":[\"zt40694-v6-21.ostk.bm2.prod.bf1.yahoo.com\"]}},\"clusterUtils\":{\"tripod\":{\"cpu\":0.191833330678661,\"mem\":0.4625738318415235,\"disk\":0.05582004563850269,\"diskbusy\":0},\"tripodaggregation\":{\"cpu\":0.2227037978608054,\"mem\":0.2051752598416401,\"disk\":0.05471533698695047,\"diskbusy\":0},\"tripodaggregationstream\":{\"cpu\":0.1869410834020498,\"mem\":0.1691722576000564,\"disk\":0.04977374774258153,\"diskbusy\":0}},\"metrics\":{\"queriesPerSecond\":0,\"writesPerSecond\":0,\"documentCount\":30916,\"queryLatencyMillis\":0,\"writeLatencyMillis\":0}},{\"zone\":{\"environment\":\"prod\",\"region\":\"us-east-3\"},\"version\":\"6.173.62\",\"deployTime\":1510817190016,\"applicationPackageRevision\":{\"applicationPackageHash\":\"9db423e1021d7b452d37ec6372bc757d9c1bda87\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"49cd7bbb1ed9f4b922083cb042590b0885ffe22b\"}},\"clusterInfo\":{\"tripod\":{\"flavor\":\"d-3-16-100\",\"cost\":9,\"flavorCpu\":0,\"flavorMem\":0,\"flavorDisk\":0,\"clusterType\":\"container\",\"hostnames\":[\"zt40738-v6-13.ostk.bm2.prod.bf1.yahoo.com\",\"zt40783-v6-31.ostk.bm2.prod.bf1.yahoo.com\"]},\"tripodaggregation\":{\"flavor\":\"d-12-64-400\",\"cost\":38,\"flavorCpu\":0,\"flavorMem\":0,\"flavorDisk\":0,\"clusterType\":\"content\",\"hostnames\":[\"zt40819-v6-7.ostk.bm2.prod.bf1.yahoo.com\",\"zt40661-v6-3.ostk.bm2.prod.bf1.yahoo.com\",\"zt40805-v6-30.ostk.bm2.prod.bf1.yahoo.com\",\"zt40702-v6-32.ostk.bm2.prod.bf1.yahoo.com\",\"zt40706-v6-3.ostk.bm2.prod.bf1.yahoo.com\",\"zt40691-v6-27.ostk.bm2.prod.bf1.yahoo.com\",\"zt40676-v6-15.ostk.bm2.prod.bf1.yahoo.com\",\"zt40788-v6-23.ostk.bm2.prod.bf1.yahoo.com\",\"zt40782-v6-30.ostk.bm2.prod.bf1.yahoo.com\",\"zt40802-v6-32.ostk.bm2.prod.bf1.yahoo.com\"]},\"tripodaggregationstream\":{\"flavor\":\"d-12-64-400\",\"cost\":38,\"flavorCpu\":0,\"flavorMem\":0,\"flavorDisk\":0,\"clusterType\":\"content\",\"hostnames\":[\"zt40779-v6-27.ostk.bm2.prod.bf1.yahoo.com\",\"zt40791-v6-15.ostk.bm2.prod.bf1.yahoo.com\",\"zt40733-v6-31.ostk.bm2.prod.bf1.yahoo.com\",\"zt40724-v6-30.ostk.bm2.prod.bf1.yahoo.com\"]}},\"clusterUtils\":{\"tripod\":{\"cpu\":0.2295038983007097,\"mem\":0.4627357390237263,\"disk\":0.05559941525894966,\"diskbusy\":0},\"tripodaggregation\":{\"cpu\":0.05340429087579549,\"mem\":0.8107630891552372,\"disk\":0.226444914138854,\"diskbusy\":0},\"tripodaggregationstream\":{\"cpu\":0.02148227413975218,\"mem\":0.02162174219104161,\"disk\":0.006057760545243265,\"diskbusy\":0}},\"metrics\":{\"queriesPerSecond\":1.734000012278557,\"writesPerSecond\":44.59999895095825,\"documentCount\":525868193.9999999,\"queryLatencyMillis\":5.65284947195106,\"writeLatencyMillis\":17.34593812832452}}],\"deploymentJobs\":{\"projectId\":102889,\"jobStatus\":[{\"jobType\":\"staging-test\",\"lastTriggered\":{\"id\":-1,\"version\":\"6.173.62\",\"revision\":{\"applicationPackageHash\":\"9db423e1021d7b452d37ec6372bc757d9c1bda87\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"49cd7bbb1ed9f4b922083cb042590b0885ffe22b\"}},\"upgrade\":true,\"reason\":\"system-test completed\",\"at\":1510830134259},\"lastCompleted\":{\"id\":1184,\"version\":\"6.173.62\",\"revision\":{\"applicationPackageHash\":\"9db423e1021d7b452d37ec6372bc757d9c1bda87\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"49cd7bbb1ed9f4b922083cb042590b0885ffe22b\"}},\"upgrade\":true,\"reason\":\"system-test completed\",\"at\":1510830684960},\"lastSuccess\":{\"id\":1184,\"version\":\"6.173.62\",\"revision\":{\"applicationPackageHash\":\"9db423e1021d7b452d37ec6372bc757d9c1bda87\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"49cd7bbb1ed9f4b922083cb042590b0885ffe22b\"}},\"upgrade\":true,\"reason\":\"system-test completed\",\"at\":1510830684960}},{\"jobType\":\"component\",\"lastCompleted\":{\"id\":849,\"version\":\"6.174.156\",\"upgrade\":false,\"reason\":\"Application commit\",\"at\":1511217733555},\"lastSuccess\":{\"id\":849,\"version\":\"6.174.156\",\"upgrade\":false,\"reason\":\"Application commit\",\"at\":1511217733555}},{\"jobType\":\"production-us-east-3\",\"lastTriggered\":{\"id\":-1,\"version\":\"6.173.62\",\"revision\":{\"applicationPackageHash\":\"9db423e1021d7b452d37ec6372bc757d9c1bda87\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"49cd7bbb1ed9f4b922083cb042590b0885ffe22b\"}},\"upgrade\":true,\"reason\":\"staging-test completed\",\"at\":1510830685127},\"lastCompleted\":{\"id\":923,\"version\":\"6.173.62\",\"revision\":{\"applicationPackageHash\":\"9db423e1021d7b452d37ec6372bc757d9c1bda87\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"49cd7bbb1ed9f4b922083cb042590b0885ffe22b\"}},\"upgrade\":true,\"reason\":\"staging-test completed\",\"at\":1510837650046},\"lastSuccess\":{\"id\":923,\"version\":\"6.173.62\",\"revision\":{\"applicationPackageHash\":\"9db423e1021d7b452d37ec6372bc757d9c1bda87\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"49cd7bbb1ed9f4b922083cb042590b0885ffe22b\"}},\"upgrade\":true,\"reason\":\"staging-test completed\",\"at\":1510837650046}},{\"jobType\":\"production-us-west-1\",\"lastTriggered\":{\"id\":-1,\"version\":\"6.173.62\",\"revision\":{\"applicationPackageHash\":\"9db423e1021d7b452d37ec6372bc757d9c1bda87\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"49cd7bbb1ed9f4b922083cb042590b0885ffe22b\"}},\"upgrade\":true,\"reason\":\"production-us-east-3 completed\",\"at\":1510837650139},\"lastCompleted\":{\"id\":646,\"version\":\"6.173.62\",\"revision\":{\"applicationPackageHash\":\"9db423e1021d7b452d37ec6372bc757d9c1bda87\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"49cd7bbb1ed9f4b922083cb042590b0885ffe22b\"}},\"upgrade\":true,\"reason\":\"production-us-east-3 completed\",\"at\":1510843559162},\"lastSuccess\":{\"id\":646,\"version\":\"6.173.62\",\"revision\":{\"applicationPackageHash\":\"9db423e1021d7b452d37ec6372bc757d9c1bda87\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"49cd7bbb1ed9f4b922083cb042590b0885ffe22b\"}},\"upgrade\":true,\"reason\":\"production-us-east-3 completed\",\"at\":1510843559162}},{\"jobType\":\"system-test\",\"jobError\":\"unknown\",\"lastTriggered\":{\"id\":-1,\"version\":\"6.173.62\",\"revision\":{\"applicationPackageHash\":\"ec548fa61cbfab7a270a51d46b1263ec1be5d9a8\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"234f3e4e77049d0b9538c9e1b356d17eb1dedb6a\"}},\"upgrade\":false,\"reason\":\"Available change in component\",\"at\":1511256608649},\"lastCompleted\":{\"id\":1686,\"version\":\"6.173.62\",\"revision\":{\"applicationPackageHash\":\"ec548fa61cbfab7a270a51d46b1263ec1be5d9a8\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"234f3e4e77049d0b9538c9e1b356d17eb1dedb6a\"}},\"upgrade\":false,\"reason\":\"Available change in component\",\"at\":1511256603353},\"firstFailing\":{\"id\":1659,\"version\":\"6.173.62\",\"revision\":{\"applicationPackageHash\":\"ec548fa61cbfab7a270a51d46b1263ec1be5d9a8\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"234f3e4e77049d0b9538c9e1b356d17eb1dedb6a\"}},\"upgrade\":false,\"reason\":\"component completed\",\"at\":1511219070725},\"lastSuccess\":{\"id\":1658,\"version\":\"6.173.62\",\"revision\":{\"applicationPackageHash\":\"9db423e1021d7b452d37ec6372bc757d9c1bda87\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"49cd7bbb1ed9f4b922083cb042590b0885ffe22b\"}},\"upgrade\":true,\"reason\":\"Upgrading to 6.173.62\",\"at\":1511175754163}}]},\"deployingField\":{\"applicationPackageHash\":\"ec548fa61cbfab7a270a51d46b1263ec1be5d9a8\",\"sourceRevision\":{\"repositoryField\":\"git@git.ouroath.com:Tripod/service-aggregation-vespa.git\",\"branchField\":\"origin/master\",\"commitField\":\"234f3e4e77049d0b9538c9e1b356d17eb1dedb6a\"}},\"outstandingChangeField\":false,\"queryQuality\":100,\"writeQuality\":99.99894341115082}"; } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ContainerControllerTester.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ContainerControllerTester.java index fc0147dacef..5b806d580e2 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ContainerControllerTester.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ContainerControllerTester.java @@ -81,7 +81,7 @@ public class ContainerControllerTester { GitRevision app1RevisionId = new GitRevision(new GitRepository("repo"), new GitBranch("master"), new GitCommit("commit1")); controller().applications().deployApplication(application.id(), zone, - applicationPackage, + Optional.of(applicationPackage), new DeployOptions(Optional.of(new ScrewdriverBuildJob(app1ScrewdriverId, app1RevisionId)), Optional.empty(), false, false)); return application; } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java index 028992e8f7d..abc5f9f8aa1 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java @@ -67,6 +67,7 @@ public class ControllerContainerTest { " <component id='com.yahoo.vespa.hosted.controller.persistence.MemoryControllerDb'/>\n" + " <component id='com.yahoo.vespa.hosted.controller.restapi.application.MockAuthorizer'/>\n" + " <component id='com.yahoo.vespa.hosted.controller.routing.MockRoutingGenerator'/>\n" + + " <component id='com.yahoo.vespa.hosted.controller.ArtifactRepositoryMock'/>\n" + " <handler id='com.yahoo.vespa.hosted.controller.restapi.application.ApplicationApiHandler'>\n" + " <binding>http://*/application/v4/*</binding>\n" + " </handler>\n" + diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-deployment-cancelled.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-deployment-cancelled.json index d1e1ebe94fd..3b6d8ed71e9 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-deployment-cancelled.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-deployment-cancelled.json @@ -1 +1 @@ -{"message":"Cancelled version change to 6.1 for application 'tenant1.application1'"} +{"message":"Cancelled upgrade to 6.1 for application 'tenant1.application1'"} diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/zone/v2/ZoneApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/zone/v2/ZoneApiTest.java index 782dc6dba4f..c52266dfacc 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/zone/v2/ZoneApiTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/zone/v2/ZoneApiTest.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.restapi.zone.v2; import com.yahoo.application.container.handler.Request; diff --git a/dist/build-rpm.sh b/dist/build-rpm.sh index e86eebe9380..5d6d2ba7809 100755 --- a/dist/build-rpm.sh +++ b/dist/build-rpm.sh @@ -1,4 +1,5 @@ #!/bin/bash +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. set -e diff --git a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/ContainerResources.java b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/ContainerResources.java index aad0b07a2c4..4c538d6a194 100644 --- a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/ContainerResources.java +++ b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/ContainerResources.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.dockerapi; /** diff --git a/eval/src/apps/eval_expr/eval_expr.cpp b/eval/src/apps/eval_expr/eval_expr.cpp index 91c669efe94..afddec40e48 100644 --- a/eval/src/apps/eval_expr/eval_expr.cpp +++ b/eval/src/apps/eval_expr/eval_expr.cpp @@ -3,7 +3,7 @@ #include <vespa/eval/eval/function.h> #include <vespa/eval/eval/interpreted_function.h> #include <vespa/eval/eval/tensor_spec.h> - +#include <vespa/eval/eval/simple_tensor_engine.h> using namespace vespalib::eval; diff --git a/eval/src/apps/tensor_conformance/tensor_conformance.cpp b/eval/src/apps/tensor_conformance/tensor_conformance.cpp index 2d7cf9b5fa0..4130c75893b 100644 --- a/eval/src/apps/tensor_conformance/tensor_conformance.cpp +++ b/eval/src/apps/tensor_conformance/tensor_conformance.cpp @@ -113,9 +113,10 @@ TensorSpec eval_expr_tf(const Inspector &test, const TensorEngine &engine) { } SimpleObjectParams params(param_refs); NodeTypes types = NodeTypes(fun, get_types(param_values)); - const auto &tfun = make_tensor_function(engine, fun.root(), types, stash); - const Value &result = tfun.eval(engine, params, stash); - ASSERT_EQUAL(result.type(), tfun.result_type()); + const auto &plain_fun = make_tensor_function(engine, fun.root(), types, stash); + const auto &optimized = engine.optimize(plain_fun, stash); + const Value &result = optimized.eval(engine, params, stash); + ASSERT_EQUAL(result.type(), plain_fun.result_type()); ASSERT_EQUAL(result.type(), types.get_type(fun.root())); return engine.to_spec(result); } diff --git a/eval/src/tests/eval/function_speed/function_speed_test.cpp b/eval/src/tests/eval/function_speed/function_speed_test.cpp index 65866de7ddd..178ab32d734 100644 --- a/eval/src/tests/eval/function_speed/function_speed_test.cpp +++ b/eval/src/tests/eval/function_speed/function_speed_test.cpp @@ -4,6 +4,7 @@ #include <vespa/eval/eval/llvm/compiled_function.h> #include <vespa/vespalib/util/benchmark_timer.h> #include <vespa/eval/eval/interpreted_function.h> +#include <vespa/eval/eval/simple_tensor_engine.h> #include <vespa/vespalib/util/benchmark_timer.h> #include <vespa/eval/tensor/default_tensor_engine.h> diff --git a/eval/src/tests/eval/gbdt/gbdt_test.cpp b/eval/src/tests/eval/gbdt/gbdt_test.cpp index af5935fbf1e..9cf5c31f76b 100644 --- a/eval/src/tests/eval/gbdt/gbdt_test.cpp +++ b/eval/src/tests/eval/gbdt/gbdt_test.cpp @@ -6,6 +6,7 @@ #include <vespa/eval/eval/llvm/deinline_forest.h> #include <vespa/eval/eval/llvm/compiled_function.h> #include <vespa/eval/eval/interpreted_function.h> +#include <vespa/eval/eval/simple_tensor_engine.h> #include <vespa/vespalib/util/stringfmt.h> #include "model.cpp" diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp index 802f9555360..f0306e99a91 100644 --- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp +++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp @@ -6,6 +6,7 @@ #include <vespa/eval/eval/interpreted_function.h> #include <vespa/eval/eval/test/eval_spec.h> #include <vespa/eval/eval/basic_nodes.h> +#include <vespa/eval/eval/simple_tensor_engine.h> #include <vespa/eval/tensor/default_tensor_engine.h> #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/util/stash.h> @@ -177,7 +178,7 @@ struct InnerProduct { InterpretedFunction interpreted; ~InnerProduct() {} InnerProduct(const vespalib::string &expr) - : engine(SimpleTensorEngine::ref()), + : engine(DefaultTensorEngine::ref()), function(Function::parse({"a", "b"}, expr)), a("null"), b("null"), expect("null"), types(), @@ -186,10 +187,10 @@ struct InnerProduct { TensorSpec a_in, TensorSpec b_in, TensorSpec expect_in) - : engine(SimpleTensorEngine::ref()), + : engine(DefaultTensorEngine::ref()), function(Function::parse(expr)), a(a_in), b(b_in), expect(expect_in), - types(function, {ValueType::from_spec(a.type()), ValueType::from_spec(a.type())}), + types(function, {ValueType::from_spec(a.type()), ValueType::from_spec(b.type())}), interpreted(engine, function, types) {} void verify_optimized() const { EXPECT_EQUAL(1u, interpreted.program_size()); @@ -296,13 +297,13 @@ TEST("require that vector matrix multiplication works with tensor function") { TEST_DO(XW("reduce(join(b,a,f(x,y)(y*x)),sum,x)").verify_optimized()); } -TEST("require that matrix multiplication works with tensor function") { - TEST_DO(MatMul("reduce(a*b,sum,y)").verify_optimized()); - TEST_DO(MatMul("reduce(join(a,b,f(x,y)(x*y)),sum,y)").verify_optimized()); - TEST_DO(MatMul("reduce(b*a,sum,y)").verify_optimized()); - TEST_DO(MatMul("reduce(join(b,a,f(x,y)(x*y)),sum,y)").verify_optimized()); - TEST_DO(MatMul("reduce(join(a,b,f(x,y)(y*x)),sum,y)").verify_optimized()); - TEST_DO(MatMul("reduce(join(b,a,f(x,y)(y*x)),sum,y)").verify_optimized()); +TEST("require that matrix multiplication is not optimized (yet)") { + TEST_DO(MatMul("reduce(a*b,sum,y)").verify_not_optimized()); + TEST_DO(MatMul("reduce(join(a,b,f(x,y)(x*y)),sum,y)").verify_not_optimized()); + TEST_DO(MatMul("reduce(b*a,sum,y)").verify_not_optimized()); + TEST_DO(MatMul("reduce(join(b,a,f(x,y)(x*y)),sum,y)").verify_not_optimized()); + TEST_DO(MatMul("reduce(join(a,b,f(x,y)(y*x)),sum,y)").verify_not_optimized()); + TEST_DO(MatMul("reduce(join(b,a,f(x,y)(y*x)),sum,y)").verify_not_optimized()); } TEST("require that expressions similar to inner product are not optimized") { diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp index fb1ca3d18fe..b2df7eddd46 100644 --- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp +++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp @@ -35,7 +35,7 @@ struct EvalCtx { return fun.eval(engine, SimpleObjectParams(params), stash); } const TensorFunction &compile(const tensor_function::Node &expr) { - return engine.compile(expr, stash); + return engine.optimize(expr, stash); } Value::UP make_true() { return engine.from_spec(TensorSpec("double").add({}, 1.0)); diff --git a/eval/src/vespa/eval/eval/CMakeLists.txt b/eval/src/vespa/eval/eval/CMakeLists.txt index 3816780d4d9..0eabb4f4219 100644 --- a/eval/src/vespa/eval/eval/CMakeLists.txt +++ b/eval/src/vespa/eval/eval/CMakeLists.txt @@ -4,6 +4,7 @@ vespa_add_library(eval_eval OBJECT aggr.cpp basic_nodes.cpp call_nodes.cpp + compile_tensor_function.cpp delete_node.cpp function.cpp gbdt.cpp diff --git a/eval/src/vespa/eval/eval/compile_tensor_function.cpp b/eval/src/vespa/eval/eval/compile_tensor_function.cpp new file mode 100644 index 00000000000..ac36720895f --- /dev/null +++ b/eval/src/vespa/eval/eval/compile_tensor_function.cpp @@ -0,0 +1,83 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "compile_tensor_function.h" +#include "tensor_function.h" + +namespace vespalib::eval { + +namespace { + +using State = InterpretedFunction::State; +using Instruction = InterpretedFunction::Instruction; + +void op_skip(State &state, uint64_t param) { + state.program_offset += param; +} + +void op_skip_if_false(State &state, uint64_t param) { + ++state.if_cnt; + if (!state.peek(0).as_bool()) { + state.program_offset += param; + } + state.stack.pop_back(); +} + +struct Frame { + const TensorFunction &node; + std::vector<TensorFunction::Child::CREF> children; + size_t child_idx; + Frame(const TensorFunction &node_in) : node(node_in), children(), child_idx(0) { node.push_children(children); } + bool has_next_child() const { return (child_idx < children.size()); } + const TensorFunction &next_child() { return children[child_idx++].get().get(); } +}; + +struct ProgramCompiler { + Stash &stash; + std::vector<Frame> stack; + std::vector<Instruction> prog; + ProgramCompiler(Stash &stash_in) : stash(stash_in), stack(), prog() {} + + void append(const std::vector<Instruction> &other_prog) { + prog.insert(prog.end(), other_prog.begin(), other_prog.end()); + } + + void open(const TensorFunction &node) { + if (auto if_node = as<tensor_function::If>(node)) { + append(compile_tensor_function(if_node->cond(), stash)); + auto true_prog = compile_tensor_function(if_node->true_child(), stash); + auto false_prog = compile_tensor_function(if_node->false_child(), stash); + true_prog.emplace_back(op_skip, false_prog.size()); + prog.emplace_back(op_skip_if_false, true_prog.size()); + append(true_prog); + append(false_prog); + } else { + stack.emplace_back(node); + } + } + + void close(const TensorFunction &node) { + prog.push_back(node.compile_self(stash)); + } + + std::vector<Instruction> compile(const TensorFunction &function) { + open(function); + while (!stack.empty()) { + if (stack.back().has_next_child()) { + open(stack.back().next_child()); + } else { + close(stack.back().node); + stack.pop_back(); + } + } + return std::move(prog); + } +}; + +} // namespace vespalib::eval::<unnamed> + +std::vector<Instruction> compile_tensor_function(const TensorFunction &function, Stash &stash) { + ProgramCompiler compiler(stash); + return compiler.compile(function); +} + +} // namespace vespalib::eval diff --git a/eval/src/vespa/eval/eval/compile_tensor_function.h b/eval/src/vespa/eval/eval/compile_tensor_function.h new file mode 100644 index 00000000000..bfac0e0f036 --- /dev/null +++ b/eval/src/vespa/eval/eval/compile_tensor_function.h @@ -0,0 +1,16 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "interpreted_function.h" +#include <vector> + +namespace vespalib { class Stash; } + +namespace vespalib::eval { + +class TensorFunction; + +std::vector<InterpretedFunction::Instruction> compile_tensor_function(const TensorFunction &function, Stash &stash); + +} // namespace vespalib::eval diff --git a/eval/src/vespa/eval/eval/interpreted_function.cpp b/eval/src/vespa/eval/eval/interpreted_function.cpp index 13ab6fe5676..28381030f24 100644 --- a/eval/src/vespa/eval/eval/interpreted_function.cpp +++ b/eval/src/vespa/eval/eval/interpreted_function.cpp @@ -6,434 +6,20 @@ #include "check_type.h" #include "tensor_spec.h" #include "operation.h" +#include "tensor_engine.h" #include <vespa/vespalib/util/classname.h> #include <vespa/eval/eval/llvm/compile_cache.h> #include <vespa/vespalib/util/benchmark_timer.h> #include <set> +#include "make_tensor_function.h" +#include "compile_tensor_function.h" + namespace vespalib { namespace eval { namespace { -using namespace nodes; -using State = InterpretedFunction::State; -using Instruction = InterpretedFunction::Instruction; -using map_fun_t = double (*)(double); -using join_fun_t = double (*)(double, double); - -//----------------------------------------------------------------------------- - -template <typename T, typename IN> -uint64_t wrap_param(const IN &value_in) { - const T &value = value_in; - return (uint64_t)&value; -} - -template <typename T> -const T &unwrap_param(uint64_t param) { return *((const T *)param); } - -//----------------------------------------------------------------------------- - -uint64_t to_param(map_fun_t value) { return (uint64_t)value; } -uint64_t to_param(join_fun_t value) { return (uint64_t)value; } -map_fun_t to_map_fun(uint64_t param) { return (map_fun_t)param; } -join_fun_t to_join_fun(uint64_t param) { return (join_fun_t)param; } - -//----------------------------------------------------------------------------- - -void op_load_const(State &state, uint64_t param) { - state.stack.push_back(unwrap_param<Value>(param)); -} - -void op_load_param(State &state, uint64_t param) { - state.stack.push_back(state.params->resolve(param, state.stash)); -} - -//----------------------------------------------------------------------------- - -void op_skip(State &state, uint64_t param) { - state.program_offset += param; -} - -void op_skip_if_false(State &state, uint64_t param) { - ++state.if_cnt; - if (!state.peek(0).as_bool()) { - state.program_offset += param; - } - state.stack.pop_back(); -} - -//----------------------------------------------------------------------------- - -void op_double_map(State &state, uint64_t param) { - state.replace(1, state.stash.create<DoubleValue>(to_map_fun(param)(state.peek(0).as_double()))); -} - -void op_double_mul(State &state, uint64_t) { - state.replace(2, state.stash.create<DoubleValue>(state.peek(1).as_double() * state.peek(0).as_double())); -} - -void op_double_add(State &state, uint64_t) { - state.replace(2, state.stash.create<DoubleValue>(state.peek(1).as_double() + state.peek(0).as_double())); -} - -void op_double_join(State &state, uint64_t param) { - state.replace(2, state.stash.create<DoubleValue>(to_join_fun(param)(state.peek(1).as_double(), state.peek(0).as_double()))); -} - -//----------------------------------------------------------------------------- - -void op_tensor_map(State &state, uint64_t param) { - state.replace(1, state.engine.map(state.peek(0), to_map_fun(param), state.stash)); -} - -void op_tensor_join(State &state, uint64_t param) { - state.replace(2, state.engine.join(state.peek(1), state.peek(0), to_join_fun(param), state.stash)); -} - -using ReduceParams = std::pair<Aggr,std::vector<vespalib::string>>; -void op_tensor_reduce(State &state, uint64_t param) { - const ReduceParams ¶ms = unwrap_param<ReduceParams>(param); - state.replace(1, state.engine.reduce(state.peek(0), params.first, params.second, state.stash)); -} - -using RenameParams = std::pair<std::vector<vespalib::string>,std::vector<vespalib::string>>; -void op_tensor_rename(State &state, uint64_t param) { - const RenameParams ¶ms = unwrap_param<RenameParams>(param); - state.replace(1, state.engine.rename(state.peek(0), params.first, params.second, state.stash)); -} - -void op_tensor_concat(State &state, uint64_t param) { - const vespalib::string &dimension = unwrap_param<vespalib::string>(param); - state.replace(2, state.engine.concat(state.peek(1), state.peek(0), dimension, state.stash)); -} - -//----------------------------------------------------------------------------- - -void op_tensor_function(State &state, uint64_t param) { - const TensorFunction &fun = unwrap_param<TensorFunction>(param); - state.stack.push_back(fun.eval(state.engine, *state.params, state.stash)); -} - -//----------------------------------------------------------------------------- - -bool step_labels(std::vector<double> &labels, const ValueType &type) { - for (size_t idx = labels.size(); idx-- > 0; ) { - labels[idx] += 1.0; - if (size_t(labels[idx]) < type.dimensions()[idx].size) { - return true; - } else { - labels[idx] = 0.0; - } - } - return false; -} - -//----------------------------------------------------------------------------- - -struct ProgramBuilder : public NodeVisitor, public NodeTraverser { - std::vector<Instruction> &program; - Stash &stash; - const TensorEngine &tensor_engine; - const NodeTypes &types; - - ProgramBuilder(std::vector<Instruction> &program_in, Stash &stash_in, const TensorEngine &tensor_engine_in, const NodeTypes &types_in) - : program(program_in), stash(stash_in), tensor_engine(tensor_engine_in), types(types_in) {} - - //------------------------------------------------------------------------- - - bool is_mul_join(const Node &node) const { - if (auto join = as<TensorJoin>(node)) { - if (auto mul = as<Mul>(join->lambda().root())) { - auto sym1 = as<Symbol>(mul->lhs()); - auto sym2 = as<Symbol>(mul->rhs()); - return (sym1 && sym2 && (sym1->id() != sym2->id())); - } - } - return false; - } - - bool is_mul(const Node &node) const { - auto mul = as<Mul>(node); - return (mul || is_mul_join(node)); - } - - bool is_typed_tensor(const Node &node) const { - const ValueType &type = types.get_type(node); - return (type.is_tensor() && !type.dimensions().empty()); - } - - bool is_typed_tensor_param(const Node &node) const { - auto sym = as<Symbol>(node); - return (sym && is_typed_tensor(node)); - } - - bool is_typed_tensor_product_of_params(const Node &node) const { - return (is_typed_tensor(node) && is_mul(node) && - is_typed_tensor_param(node.get_child(0)) && - is_typed_tensor_param(node.get_child(1))); - } - - //------------------------------------------------------------------------- - - void make_const_op(const Node &node, const Value &value) { - (void) node; - program.emplace_back(op_load_const, wrap_param<Value>(value)); - } - - void make_map_op(const Node &node, map_fun_t function) { - if (types.get_type(node).is_double()) { - program.emplace_back(op_double_map, to_param(function)); - } else { - program.emplace_back(op_tensor_map, to_param(function)); - } - } - - void make_join_op(const Node &node, join_fun_t function) { - if (types.get_type(node).is_double()) { - if (function == operation::Mul::f) { - program.emplace_back(op_double_mul); - } else if (function == operation::Add::f) { - program.emplace_back(op_double_add); - } else { - program.emplace_back(op_double_join, to_param(function)); - } - } else { - program.emplace_back(op_tensor_join, to_param(function)); - } - } - - //------------------------------------------------------------------------- - - void visit(const Number &node) override { - make_const_op(node, stash.create<DoubleValue>(node.value())); - } - void visit(const Symbol &node) override { - program.emplace_back(op_load_param, node.id()); - } - void visit(const String &node) override { - make_const_op(node, stash.create<DoubleValue>(node.hash())); - } - void visit(const In &node) override { - auto my_in = std::make_unique<In>(std::make_unique<Symbol>(0)); - for (size_t i = 0; i < node.num_entries(); ++i) { - my_in->add_entry(std::make_unique<Number>(node.get_entry(i).get_const_value())); - } - Function my_fun(std::move(my_in), {"x"}); - const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(my_fun, PassParams::SEPARATE)); - make_map_op(node, token.get()->get().get_function<1>()); - } - void visit(const Neg &node) override { - make_map_op(node, operation::Neg::f); - } - void visit(const Not &node) override { - make_map_op(node, operation::Not::f); - } - void visit(const If &node) override { - node.cond().traverse(*this); - size_t after_cond = program.size(); - program.emplace_back(op_skip_if_false); - node.true_expr().traverse(*this); - size_t after_true = program.size(); - program.emplace_back(op_skip); - node.false_expr().traverse(*this); - program[after_cond].update_param(after_true - after_cond); - program[after_true].update_param(program.size() - after_true - 1); - } - void visit(const Error &node) override { - make_const_op(node, ErrorValue::instance); - } - void visit(const TensorMap &node) override { - const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE)); - make_map_op(node, token.get()->get().get_function<1>()); - } - void visit(const TensorJoin &node) override { - const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE)); - make_join_op(node, token.get()->get().get_function<2>()); - } - void visit(const TensorReduce &node) override { - if ((node.aggr() == Aggr::SUM) && is_typed_tensor_product_of_params(node.get_child(0))) { - assert(program.size() >= 3); // load,load,mul - program.pop_back(); // mul - program.pop_back(); // load - program.pop_back(); // load - auto a = as<Symbol>(node.get_child(0).get_child(0)); - auto b = as<Symbol>(node.get_child(0).get_child(1)); - const auto &ir = tensor_function::reduce(tensor_function::join( - tensor_function::inject(types.get_type(*a), a->id(), stash), - tensor_function::inject(types.get_type(*b), b->id(), stash), - operation::Mul::f, stash), node.aggr(), node.dimensions(), stash); - const auto &fun = tensor_engine.compile(ir, stash); - program.emplace_back(op_tensor_function, wrap_param<TensorFunction>(fun)); - } else { - ReduceParams ¶ms = stash.create<ReduceParams>(node.aggr(), node.dimensions()); - program.emplace_back(op_tensor_reduce, wrap_param<ReduceParams>(params)); - } - } - void visit(const TensorRename &node) override { - RenameParams ¶ms = stash.create<RenameParams>(node.from(), node.to()); - program.emplace_back(op_tensor_rename, wrap_param<RenameParams>(params)); - } - void visit(const TensorLambda &node) override { - const auto &type = node.type(); - TensorSpec spec(type.to_spec()); - const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::ARRAY)); - auto fun = token.get()->get().get_function(); - std::vector<double> params(type.dimensions().size(), 0.0); - assert(token.get()->get().num_params() == params.size()); - do { - TensorSpec::Address addr; - for (size_t i = 0; i < params.size(); ++i) { - addr.emplace(type.dimensions()[i].name, size_t(params[i])); - } - spec.add(addr, fun(¶ms[0])); - } while (step_labels(params, type)); - make_const_op(node, *stash.create<Value::UP>(tensor_engine.from_spec(spec))); - } - void visit(const TensorConcat &node) override { - vespalib::string &dimension = stash.create<vespalib::string>(node.dimension()); - program.emplace_back(op_tensor_concat, wrap_param<vespalib::string>(dimension)); - } - void visit(const Add &node) override { - make_join_op(node, operation::Add::f); - } - void visit(const Sub &node) override { - make_join_op(node, operation::Sub::f); - } - void visit(const Mul &node) override { - make_join_op(node, operation::Mul::f); - } - void visit(const Div &node) override { - make_join_op(node, operation::Div::f); - } - void visit(const Mod &node) override { - make_join_op(node, operation::Mod::f); - } - void visit(const Pow &node) override { - make_join_op(node, operation::Pow::f); - } - void visit(const Equal &node) override { - make_join_op(node, operation::Equal::f); - } - void visit(const NotEqual &node) override { - make_join_op(node, operation::NotEqual::f); - } - void visit(const Approx &node) override { - make_join_op(node, operation::Approx::f); - } - void visit(const Less &node) override { - make_join_op(node, operation::Less::f); - } - void visit(const LessEqual &node) override { - make_join_op(node, operation::LessEqual::f); - } - void visit(const Greater &node) override { - make_join_op(node, operation::Greater::f); - } - void visit(const GreaterEqual &node) override { - make_join_op(node, operation::GreaterEqual::f); - } - void visit(const And &node) override { - make_join_op(node, operation::And::f); - } - void visit(const Or &node) override { - make_join_op(node, operation::Or::f); - } - void visit(const Cos &node) override { - make_map_op(node, operation::Cos::f); - } - void visit(const Sin &node) override { - make_map_op(node, operation::Sin::f); - } - void visit(const Tan &node) override { - make_map_op(node, operation::Tan::f); - } - void visit(const Cosh &node) override { - make_map_op(node, operation::Cosh::f); - } - void visit(const Sinh &node) override { - make_map_op(node, operation::Sinh::f); - } - void visit(const Tanh &node) override { - make_map_op(node, operation::Tanh::f); - } - void visit(const Acos &node) override { - make_map_op(node, operation::Acos::f); - } - void visit(const Asin &node) override { - make_map_op(node, operation::Asin::f); - } - void visit(const Atan &node) override { - make_map_op(node, operation::Atan::f); - } - void visit(const Exp &node) override { - make_map_op(node, operation::Exp::f); - } - void visit(const Log10 &node) override { - make_map_op(node, operation::Log10::f); - } - void visit(const Log &node) override { - make_map_op(node, operation::Log::f); - } - void visit(const Sqrt &node) override { - make_map_op(node, operation::Sqrt::f); - } - void visit(const Ceil &node) override { - make_map_op(node, operation::Ceil::f); - } - void visit(const Fabs &node) override { - make_map_op(node, operation::Fabs::f); - } - void visit(const Floor &node) override { - make_map_op(node, operation::Floor::f); - } - void visit(const Atan2 &node) override { - make_join_op(node, operation::Atan2::f); - } - void visit(const Ldexp &node) override { - make_join_op(node, operation::Ldexp::f); - } - void visit(const Pow2 &node) override { - make_join_op(node, operation::Pow::f); - } - void visit(const Fmod &node) override { - make_join_op(node, operation::Mod::f); - } - void visit(const Min &node) override { - make_join_op(node, operation::Min::f); - } - void visit(const Max &node) override { - make_join_op(node, operation::Max::f); - } - void visit(const IsNan &node) override { - make_map_op(node, operation::IsNan::f); - } - void visit(const Relu &node) override { - make_map_op(node, operation::Relu::f); - } - void visit(const Sigmoid &node) override { - make_map_op(node, operation::Sigmoid::f); - } - void visit(const Elu &node) override { - make_map_op(node, operation::Elu::f); - } - - //------------------------------------------------------------------------- - - bool open(const Node &node) override { - if (check_type<If>(node)) { - node.accept(*this); - return false; - } - return true; - } - - void close(const Node &node) override { - node.accept(*this); - } -}; - const Function *get_lambda(const nodes::Node &node) { if (auto ptr = as<nodes::TensorMap>(node)) { return &ptr->lambda(); @@ -489,8 +75,9 @@ InterpretedFunction::InterpretedFunction(const TensorEngine &engine, const nodes _num_params(num_params_in), _tensor_engine(engine) { - ProgramBuilder program_builder(_program, _stash, _tensor_engine, types); - root.traverse(program_builder); + const TensorFunction &plain_fun = make_tensor_function(engine, root, types, _stash); + const TensorFunction &optimized = engine.optimize(plain_fun, _stash); + _program = compile_tensor_function(optimized, _stash); } InterpretedFunction::~InterpretedFunction() {} diff --git a/eval/src/vespa/eval/eval/interpreted_function.h b/eval/src/vespa/eval/eval/interpreted_function.h index 2a52a5a8258..1c57b20682f 100644 --- a/eval/src/vespa/eval/eval/interpreted_function.h +++ b/eval/src/vespa/eval/eval/interpreted_function.h @@ -3,7 +3,6 @@ #pragma once #include "function.h" -#include "simple_tensor_engine.h" #include "node_types.h" #include "lazy_params.h" #include <vespa/vespalib/util/stash.h> diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp index d28c4812a31..d84d9f53749 100644 --- a/eval/src/vespa/eval/eval/make_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp @@ -32,6 +32,21 @@ bool step_labels(std::vector<double> &labels, const ValueType &type) { return false; } +// TODO(havardpe): generic function pointer resolving for all single +// operation lambdas. + +template <typename OP2> +bool is_op2(const Function &lambda) { + if (lambda.num_params() == 2) { + if (auto op2 = as<OP2>(lambda.root())) { + auto sym1 = as<Symbol>(op2->lhs()); + auto sym2 = as<Symbol>(op2->rhs()); + return (sym1 && sym2 && (sym1->id() != sym2->id())); + } + } + return false; +} + //----------------------------------------------------------------------------- struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { @@ -135,8 +150,14 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { make_map(node, token.get()->get().get_function<1>()); } void visit(const TensorJoin &node) override { - const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE)); - make_join(node, token.get()->get().get_function<2>()); + if (is_op2<Mul>(node.lambda())) { + make_join(node, operation::Mul::f); + } else if (is_op2<Add>(node.lambda())) { + make_join(node, operation::Add::f); + } else { + const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE)); + make_join(node, token.get()->get().get_function<2>()); + } } void visit(const TensorReduce &node) override { make_reduce(node, node.aggr(), node.dimensions()); diff --git a/eval/src/vespa/eval/eval/tensor_engine.h b/eval/src/vespa/eval/eval/tensor_engine.h index 02a7f0c655a..a01a6f889fd 100644 --- a/eval/src/vespa/eval/eval/tensor_engine.h +++ b/eval/src/vespa/eval/eval/tensor_engine.h @@ -47,7 +47,7 @@ struct TensorEngine virtual void encode(const Value &value, nbostream &output) const = 0; virtual Value::UP decode(nbostream &input) const = 0; - virtual const TensorFunction &compile(const tensor_function::Node &expr, Stash &) const { return expr; } + virtual const TensorFunction &optimize(const TensorFunction &expr, Stash &) const { return expr; } virtual const Value &map(const Value &a, map_fun_t function, Stash &stash) const = 0; virtual const Value &join(const Value &a, const Value &b, join_fun_t function, Stash &stash) const = 0; diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp index 8427cc53a16..62e547cbd7e 100644 --- a/eval/src/vespa/eval/eval/tensor_function.cpp +++ b/eval/src/vespa/eval/eval/tensor_function.cpp @@ -11,6 +11,86 @@ namespace vespalib { namespace eval { namespace tensor_function { +namespace { + +using State = InterpretedFunction::State; +using Instruction = InterpretedFunction::Instruction; + +//----------------------------------------------------------------------------- + +template <typename T, typename IN> +uint64_t wrap_param(const IN &value_in) { + const T &value = value_in; + return (uint64_t)&value; +} + +template <typename T> +const T &unwrap_param(uint64_t param) { return *((const T *)param); } + +//----------------------------------------------------------------------------- + +uint64_t to_param(map_fun_t value) { return (uint64_t)value; } +uint64_t to_param(join_fun_t value) { return (uint64_t)value; } +map_fun_t to_map_fun(uint64_t param) { return (map_fun_t)param; } +join_fun_t to_join_fun(uint64_t param) { return (join_fun_t)param; } + +//----------------------------------------------------------------------------- + +void op_load_const(State &state, uint64_t param) { + state.stack.push_back(unwrap_param<Value>(param)); +} + +void op_load_param(State &state, uint64_t param) { + state.stack.push_back(state.params->resolve(param, state.stash)); +} + +//----------------------------------------------------------------------------- + +void op_double_map(State &state, uint64_t param) { + state.replace(1, state.stash.create<DoubleValue>(to_map_fun(param)(state.peek(0).as_double()))); +} + +void op_double_mul(State &state, uint64_t) { + state.replace(2, state.stash.create<DoubleValue>(state.peek(1).as_double() * state.peek(0).as_double())); +} + +void op_double_add(State &state, uint64_t) { + state.replace(2, state.stash.create<DoubleValue>(state.peek(1).as_double() + state.peek(0).as_double())); +} + +void op_double_join(State &state, uint64_t param) { + state.replace(2, state.stash.create<DoubleValue>(to_join_fun(param)(state.peek(1).as_double(), state.peek(0).as_double()))); +} + +//----------------------------------------------------------------------------- + +void op_tensor_map(State &state, uint64_t param) { + state.replace(1, state.engine.map(state.peek(0), to_map_fun(param), state.stash)); +} + +void op_tensor_join(State &state, uint64_t param) { + state.replace(2, state.engine.join(state.peek(1), state.peek(0), to_join_fun(param), state.stash)); +} + +using ReduceParams = std::pair<Aggr,std::vector<vespalib::string>>; +void op_tensor_reduce(State &state, uint64_t param) { + const ReduceParams ¶ms = unwrap_param<ReduceParams>(param); + state.replace(1, state.engine.reduce(state.peek(0), params.first, params.second, state.stash)); +} + +using RenameParams = std::pair<std::vector<vespalib::string>,std::vector<vespalib::string>>; +void op_tensor_rename(State &state, uint64_t param) { + const RenameParams ¶ms = unwrap_param<RenameParams>(param); + state.replace(1, state.engine.rename(state.peek(0), params.first, params.second, state.stash)); +} + +void op_tensor_concat(State &state, uint64_t param) { + const vespalib::string &dimension = unwrap_param<vespalib::string>(param); + state.replace(2, state.engine.concat(state.peek(1), state.peek(0), dimension, state.stash)); +} + +} // namespace vespalib::eval::tensor_function + //----------------------------------------------------------------------------- void @@ -39,6 +119,12 @@ ConstValue::eval(const TensorEngine &, const LazyParams &, Stash &) const return _value; } +Instruction +ConstValue::compile_self(Stash &) const +{ + return Instruction(op_load_const, wrap_param<Value>(_value)); +} + //----------------------------------------------------------------------------- const Value & @@ -47,6 +133,12 @@ Inject::eval(const TensorEngine &, const LazyParams ¶ms, Stash &stash) const return params.resolve(_param_idx, stash); } +Instruction +Inject::compile_self(Stash &) const +{ + return Instruction(op_load_param, _param_idx); +} + //----------------------------------------------------------------------------- const Value & @@ -56,6 +148,13 @@ Reduce::eval(const TensorEngine &engine, const LazyParams ¶ms, Stash &stash) return engine.reduce(a, _aggr, _dimensions, stash); } +Instruction +Reduce::compile_self(Stash &stash) const +{ + ReduceParams ¶ms = stash.create<ReduceParams>(_aggr, _dimensions); + return Instruction(op_tensor_reduce, wrap_param<ReduceParams>(params)); +} + //----------------------------------------------------------------------------- const Value & @@ -65,6 +164,15 @@ Map::eval(const TensorEngine &engine, const LazyParams ¶ms, Stash &stash) co return engine.map(a, _function, stash); } +Instruction +Map::compile_self(Stash &) const +{ + if (result_type().is_double()) { + return Instruction(op_double_map, to_param(_function)); + } + return Instruction(op_tensor_map, to_param(_function)); +} + //----------------------------------------------------------------------------- const Value & @@ -75,6 +183,21 @@ Join::eval(const TensorEngine &engine, const LazyParams ¶ms, Stash &stash) c return engine.join(a, b, _function, stash); } +Instruction +Join::compile_self(Stash &) const +{ + if (result_type().is_double()) { + if (_function == operation::Mul::f) { + return Instruction(op_double_mul); + } + if (_function == operation::Add::f) { + return Instruction(op_double_add); + } + return Instruction(op_double_join, to_param(_function)); + } + return Instruction(op_tensor_join, to_param(_function)); +} + //----------------------------------------------------------------------------- const Value & @@ -85,6 +208,12 @@ Concat::eval(const TensorEngine &engine, const LazyParams ¶ms, Stash &stash) return engine.concat(a, b, _dimension, stash); } +Instruction +Concat::compile_self(Stash &) const +{ + return Instruction(op_tensor_concat, wrap_param<vespalib::string>(_dimension)); +} + //----------------------------------------------------------------------------- const Value & @@ -94,6 +223,13 @@ Rename::eval(const TensorEngine &engine, const LazyParams ¶ms, Stash &stash) return engine.rename(a, _from, _to, stash); } +Instruction +Rename::compile_self(Stash &stash) const +{ + RenameParams ¶ms = stash.create<RenameParams>(_from, _to); + return Instruction(op_tensor_rename, wrap_param<RenameParams>(params)); +} + //----------------------------------------------------------------------------- void @@ -112,6 +248,14 @@ If::eval(const TensorEngine &engine, const LazyParams ¶ms, Stash &stash) con : false_child().eval(engine, params, stash)); } +Instruction +If::compile_self(Stash &) const +{ + // 'if' is handled directly by compile_tensor_function to enable + // lazy-evaluation of true/false sub-expressions. + abort(); +} + //----------------------------------------------------------------------------- const Node &const_value(const Value &value, Stash &stash) { diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h index d9ee5cc068c..c739ea8cba9 100644 --- a/eval/src/vespa/eval/eval/tensor_function.h +++ b/eval/src/vespa/eval/eval/tensor_function.h @@ -12,6 +12,8 @@ #include "value.h" #include "aggr.h" +#include "interpreted_function.h" + namespace vespalib { class Stash; @@ -75,6 +77,17 @@ struct TensorFunction virtual void push_children(std::vector<Child::CREF> &children) const = 0; /** + * Compile this node into a single instruction that can be run by + * an interpreted function. Sub-expressions are compiled as + * separate instructions and their results will be available on + * the value stack during execution. + * + * @return instruction representing the operation of this node + * @param stash heterogeneous object store + **/ + virtual InterpretedFunction::Instruction compile_self(Stash &stash) const = 0; + + /** * Evaluate this tensor function based on the given * parameters. The given stash can be used to store temporary * objects that need to be kept alive for the return value to be @@ -157,6 +170,7 @@ private: public: ConstValue(const Value &value_in) : Leaf(value_in.type()), _value(value_in) {} const Value &eval(const TensorEngine &engine, const LazyParams ¶ms, Stash &) const final override; + InterpretedFunction::Instruction compile_self(Stash &stash) const final override; }; //----------------------------------------------------------------------------- @@ -170,6 +184,7 @@ public: : Leaf(result_type_in), _param_idx(param_idx_in) {} size_t param_idx() const { return _param_idx; } const Value &eval(const TensorEngine &engine, const LazyParams ¶ms, Stash &) const final override; + InterpretedFunction::Instruction compile_self(Stash &stash) const final override; }; //----------------------------------------------------------------------------- @@ -188,6 +203,7 @@ public: Aggr aggr() const { return _aggr; } const std::vector<vespalib::string> &dimensions() const { return _dimensions; } const Value &eval(const TensorEngine &engine, const LazyParams ¶ms, Stash &stash) const final override; + InterpretedFunction::Instruction compile_self(Stash &stash) const final override; }; //----------------------------------------------------------------------------- @@ -203,6 +219,7 @@ public: : Op1(result_type_in, child_in), _function(function_in) {} map_fun_t function() const { return _function; } const Value &eval(const TensorEngine &engine, const LazyParams ¶ms, Stash &stash) const final override; + InterpretedFunction::Instruction compile_self(Stash &stash) const final override; }; //----------------------------------------------------------------------------- @@ -219,6 +236,7 @@ public: : Op2(result_type_in, lhs_in, rhs_in), _function(function_in) {} join_fun_t function() const { return _function; } const Value &eval(const TensorEngine &engine, const LazyParams ¶ms, Stash &stash) const final override; + InterpretedFunction::Instruction compile_self(Stash &stash) const final override; }; //----------------------------------------------------------------------------- @@ -235,6 +253,7 @@ public: : Op2(result_type_in, lhs_in, rhs_in), _dimension(dimension_in) {} const vespalib::string &dimension() const { return _dimension; } const Value &eval(const TensorEngine &engine, const LazyParams ¶ms, Stash &stash) const final override; + InterpretedFunction::Instruction compile_self(Stash &stash) const final override; }; //----------------------------------------------------------------------------- @@ -253,6 +272,7 @@ public: const std::vector<vespalib::string> &from() const { return _from; } const std::vector<vespalib::string> &to() const { return _to; } const Value &eval(const TensorEngine &engine, const LazyParams ¶ms, Stash &stash) const final override; + InterpretedFunction::Instruction compile_self(Stash &stash) const final override; }; //----------------------------------------------------------------------------- @@ -274,6 +294,7 @@ public: const TensorFunction &false_child() const { return _false_child.get(); } void push_children(std::vector<Child::CREF> &children) const final override; const Value &eval(const TensorEngine &engine, const LazyParams ¶ms, Stash &stash) const final override; + InterpretedFunction::Instruction compile_self(Stash &stash) const final override; }; //----------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp index 8fe0732f3c4..43ced9513f0 100644 --- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp +++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp @@ -276,7 +276,7 @@ struct RetainedReduce : Eval { auto a_type = ValueType::from_spec(a.type()); const auto &ir = tensor_function::reduce(tensor_function::inject(a_type, tensor_id_a, stash), aggr, dimensions, stash); ValueType expect_type = ir.result_type(); - const auto &fun = engine.compile(ir, stash); + const auto &fun = engine.optimize(ir, stash); Input input(engine.from_spec(a)); return Result(engine, check_type(fun.eval(engine, input.get(), stash), expect_type)); } @@ -291,7 +291,7 @@ struct RetainedMap : Eval { auto a_type = ValueType::from_spec(a.type()); const auto &ir = tensor_function::map(tensor_function::inject(a_type, tensor_id_a, stash), function, stash); ValueType expect_type = ir.result_type(); - const auto &fun = engine.compile(ir, stash); + const auto &fun = engine.optimize(ir, stash); Input input(engine.from_spec(a)); return Result(engine, check_type(fun.eval(engine, input.get(), stash), expect_type)); } @@ -309,7 +309,7 @@ struct RetainedJoin : Eval { tensor_function::inject(b_type, tensor_id_b, stash), function, stash); ValueType expect_type = ir.result_type(); - const auto &fun = engine.compile(ir, stash); + const auto &fun = engine.optimize(ir, stash); Input input(engine.from_spec(a), engine.from_spec(b)); return Result(engine, check_type(fun.eval(engine, input.get(), stash), expect_type)); } diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp index c9f3be9d588..9477b36463a 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp @@ -206,17 +206,13 @@ DefaultTensorEngine::decode(nbostream &input) const //----------------------------------------------------------------------------- const TensorFunction & -DefaultTensorEngine::compile(const eval::tensor_function::Node &expr, Stash &stash) const +DefaultTensorEngine::optimize(const TensorFunction &expr, Stash &stash) const { - using Node = eval::tensor_function::Node; - using Child = Node::Child; + using Child = TensorFunction::Child; Child root(expr); std::vector<Child::CREF> nodes({root}); for (size_t i = 0; i < nodes.size(); ++i) { - const Child &child = nodes[i]; - const Node *node = dynamic_cast<const Node *>(&child.get()); - assert(node != nullptr); - node->push_children(nodes); + nodes[i].get().get().push_children(nodes); } while (!nodes.empty()) { const Child &child = nodes.back(); diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.h b/eval/src/vespa/eval/tensor/default_tensor_engine.h index 1cef4ba2d35..755bdcf6a9d 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.h +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.h @@ -25,7 +25,7 @@ public: void encode(const Value &value, nbostream &output) const override; Value::UP decode(nbostream &input) const override; - const TensorFunction &compile(const eval::tensor_function::Node &expr, Stash &stash) const override; + const TensorFunction &optimize(const TensorFunction &expr, Stash &stash) const override; const Value &map(const Value &a, map_fun_t function, Stash &stash) const override; const Value &join(const Value &a, const Value &b, join_fun_t function, Stash &stash) const override; diff --git a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp index 9f09940806b..0f395bd353b 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp @@ -26,6 +26,17 @@ getCellsRef(const eval::Value &value) return denseTensor.cellsRef(); } +void op_call_leaf_eval(eval::InterpretedFunction::State &state, uint64_t param) { + DenseDotProductFunction *self = (DenseDotProductFunction *)(param); + state.stack.push_back(self->eval(state.engine, *state.params, state.stash)); +} + +} + +eval::InterpretedFunction::Instruction +DenseDotProductFunction::compile_self(Stash &) const +{ + return eval::InterpretedFunction::Instruction(op_call_leaf_eval, (uint64_t)(this)); } const eval::Value & diff --git a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h index 4e3a54ca18d..d313602bd53 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h +++ b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h @@ -23,6 +23,7 @@ public: size_t rhsTensorId() const { return _rhsTensorId; } const eval::ValueType &result_type() const override { return eval::DoubleValue::double_type(); } void push_children(std::vector<Child::CREF> &) const override {} + eval::InterpretedFunction::Instruction compile_self(Stash &stash) const override; const eval::Value &eval(const eval::TensorEngine &engine, const eval::LazyParams ¶ms, Stash &stash) const override; }; diff --git a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp index 50ab6efc931..a62dafb6831 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp @@ -69,8 +69,19 @@ getCellsRef(const eval::Value &value) return denseTensor.cellsRef(); } +void op_call_leaf_eval(eval::InterpretedFunction::State &state, uint64_t param) { + DenseXWProductFunction *self = (DenseXWProductFunction *)(param); + state.stack.push_back(self->eval(state.engine, *state.params, state.stash)); +} + } // namespace <unnamed> +eval::InterpretedFunction::Instruction +DenseXWProductFunction::compile_self(Stash &) const +{ + return eval::InterpretedFunction::Instruction(op_call_leaf_eval, (uint64_t)(this)); +} + const eval::Value & DenseXWProductFunction::eval(const eval::TensorEngine &, const eval::LazyParams ¶ms, Stash &stash) const { diff --git a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h index c6a466dc527..4d2a85d96f7 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h +++ b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h @@ -47,6 +47,7 @@ public: const eval::ValueType &result_type() const override { return _resultType; } void push_children(std::vector<Child::CREF> &) const override {} + eval::InterpretedFunction::Instruction compile_self(Stash &stash) const override; const eval::Value &eval(const eval::TensorEngine &engine, const eval::LazyParams ¶ms, Stash &stash) const override; }; diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDistributionRpcServer.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDistributionRpcServer.java index b2d1af15867..e9d2e9f7e8a 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDistributionRpcServer.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDistributionRpcServer.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.filedistribution; import com.yahoo.config.FileReference; diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java index eb69a1492bf..c6b2dd32be7 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.filedistribution; import com.google.common.util.concurrent.SettableFuture; diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceData.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceData.java index dabdba2bfc0..ceb43ab3d51 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceData.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceData.java @@ -64,4 +64,10 @@ public abstract class FileReferenceData { * @return number of bytes */ public abstract long size(); + + /** + * Close underlying files + * + */ + public abstract void close(); } diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDataBlob.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDataBlob.java index 3759cbe2ef7..f0db12a45fc 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDataBlob.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDataBlob.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.filedistribution; import com.yahoo.config.FileReference; @@ -41,4 +42,9 @@ public class FileReferenceDataBlob extends FileReferenceData { public long size() { return content.length; } + + @Override + public void close() { + // no-op + } } diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownload.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownload.java index 048287f0892..904eaae8c4a 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownload.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownload.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.filedistribution; diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java index 6fac2becf1b..20ad2e48fe2 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.filedistribution; import com.google.common.collect.ImmutableMap; diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/LazyFileReferenceData.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/LazyFileReferenceData.java index 1681843a818..0bc8f3b162a 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/LazyFileReferenceData.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/LazyFileReferenceData.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.filedistribution; import com.yahoo.config.FileReference; @@ -49,4 +50,12 @@ public class LazyFileReferenceData extends FileReferenceData { throw new RuntimeException(e); } } + + public void close() { + try { + channel.close(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } } diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/RpcTester.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/RpcTester.java index 28935c203fe..26cc3553c1f 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/RpcTester.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/RpcTester.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.filedistribution; diff --git a/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java b/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java index dc19c7521a9..589fbe29abf 100644 --- a/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java +++ b/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.filedistribution; diff --git a/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileReceiverTest.java b/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileReceiverTest.java index 762817c27ef..afa66f89efc 100644 --- a/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileReceiverTest.java +++ b/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileReceiverTest.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.filedistribution; import com.yahoo.config.FileReference; diff --git a/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/status/FileDistributionStatusClientTest.java b/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/status/FileDistributionStatusClientTest.java index fcbe880bfc7..43eb006cc6d 100644 --- a/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/status/FileDistributionStatusClientTest.java +++ b/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/status/FileDistributionStatusClientTest.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.filedistribution.status; diff --git a/node-admin/pom.xml b/node-admin/pom.xml index 7b3b787b503..161769a4edf 100644 --- a/node-admin/pom.xml +++ b/node-admin/pom.xml @@ -116,12 +116,11 @@ <scope>test</scope> </dependency> <dependency> - <groupId>com.google.jimfs</groupId> - <artifactId>jimfs</artifactId> - <scope>test</scope> + <groupId>org.apache.velocity</groupId> + <artifactId>velocity</artifactId> + <scope>compile</scope> </dependency> </dependencies> - <build> <plugins> <plugin> diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/component/TaskContext.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/component/TaskContext.java index 9def627e87f..87491367514 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/component/TaskContext.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/component/TaskContext.java @@ -18,4 +18,6 @@ public interface TaskContext { FileSystem fileSystem(); void logSystemModification(Logger logger, String actionDescription); + + default boolean executeSubtask(IdempotentTask task) { return false; } } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java index 32f6186707a..edf4f059fc2 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java @@ -58,6 +58,7 @@ public class NodeAgentImpl implements NodeAgent { private boolean isFrozen = true; private boolean wantFrozen = false; private boolean workToDoNow = true; + private boolean expectNodeNotInNodeRepo = false; private final Object monitor = new Object(); @@ -378,7 +379,9 @@ public class NodeAgentImpl implements NodeAgent { boolean isFrozenCopy; synchronized (monitor) { while (!workToDoNow) { - long remainder = timeBetweenEachConverge.minus(Duration.between(lastConverge, clock.instant())).toMillis(); + long remainder = timeBetweenEachConverge + .minus(Duration.between(lastConverge, clock.instant())) + .toMillis(); if (remainder > 0) { try { monitor.wait(remainder); @@ -413,7 +416,7 @@ public class NodeAgentImpl implements NodeAgent { // therefore be reset if we get an exception from docker. numberOfUnhandledException++; containerState = UNKNOWN; - logger.error("Caught a DockerExecption, resetting containerState to " + containerState, e); + logger.error("Caught a DockerException, resetting containerState to " + containerState, e); } catch (Exception e) { numberOfUnhandledException++; logger.error("Unhandled exception, ignoring.", e); @@ -427,9 +430,15 @@ public class NodeAgentImpl implements NodeAgent { // Public for testing void converge() { - final ContainerNodeSpec nodeSpec = nodeRepository.getContainerNodeSpec(hostname) - .orElseThrow(() -> - new IllegalStateException(String.format("Node '%s' missing from node repository.", hostname))); + final Optional<ContainerNodeSpec> nodeSpecOptional = nodeRepository.getContainerNodeSpec(hostname); + + // We just removed the node from node repo, so this is expected until NodeAdmin stop this NodeAgent + if (!nodeSpecOptional.isPresent() && expectNodeNotInNodeRepo) return; + + final ContainerNodeSpec nodeSpec = nodeSpecOptional.orElseThrow(() -> + new IllegalStateException(String.format("Node '%s' missing from node repository.", hostname))); + expectNodeNotInNodeRepo = false; + Optional<Container> container = getContainer(); if (!nodeSpec.equals(lastNodeSpec)) { @@ -499,6 +508,7 @@ public class NodeAgentImpl implements NodeAgent { storageMaintainer.cleanupNodeStorage(containerName, nodeSpec); updateNodeRepoWithCurrentAttributes(nodeSpec); nodeRepository.markNodeAvailableForNewAllocation(hostname); + expectNodeNotInNodeRepo = true; break; default: throw new RuntimeException("UNKNOWN STATE " + nodeSpec.nodeState.name()); diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/DebugHandlerHelper.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/DebugHandlerHelper.java new file mode 100644 index 00000000000..dfcaba7c4bb --- /dev/null +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/DebugHandlerHelper.java @@ -0,0 +1,53 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.hosted.node.admin.provider; + +import javax.annotation.concurrent.ThreadSafe; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +/** + * Class to make it easier to implement a NodeAdminDebugHandler: + * - Forward to sub-NodeAdminDebugHandlers with addHandler, + * - Specify constants with addConstant + * - Forwarding to methods that dynamically build debug objects with addThreadSafeSupplier. + * + * @author hakonhall + */ +@ThreadSafe +public class DebugHandlerHelper implements NodeAdminDebugHandler { + private Object monitor = new Object(); + private final ConcurrentMap<String, Supplier<Object>> suppliers = new ConcurrentHashMap<>(); + + public void addThreadSafeSupplier(String name, Supplier<Object> threadSafeSupplier) { + Supplier<Object> previousSupplier = suppliers.putIfAbsent(name, threadSafeSupplier); + if (previousSupplier != null) { + throw new IllegalArgumentException(name + " is already registered"); + } + } + + public void addHandler(String name, NodeAdminDebugHandler handler) { + addThreadSafeSupplier(name, () -> handler.getDebugPage()); + } + + public void addConstant(String name, String value) { + addThreadSafeSupplier(name, () -> value); + } + + public void remove(String name) { + Supplier<Object> supplier = suppliers.remove(name); + if (supplier == null) { + throw new IllegalArgumentException(name + " is not registered"); + } + } + + @Override + public Map<String, Object> getDebugPage() { + return suppliers.entrySet().stream().collect(Collectors.toMap( + Map.Entry::getKey, + entry -> entry.getValue().get())); + } +} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminDebugHandler.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminDebugHandler.java new file mode 100644 index 00000000000..7b5eaa2f326 --- /dev/null +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminDebugHandler.java @@ -0,0 +1,20 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.hosted.node.admin.provider; + +import javax.annotation.concurrent.ThreadSafe; +import java.util.Map; + +/** + * Interface for supporting debug info to introspect e.g. internal state. + * + * @author hakonhall + */ +@ThreadSafe +public interface NodeAdminDebugHandler { + /** + * The Object in the map values must be serializable with Jackson's ObjectMapper. + * May be called concurrently by different threads. + */ + Map<String, Object> getDebugPage(); +} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminStateUpdater.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminStateUpdater.java index 755e1301c12..841f464e014 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminStateUpdater.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminStateUpdater.java @@ -1,9 +1,10 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.node.admin.provider; -import java.util.Map; +import javax.annotation.concurrent.ThreadSafe; -public interface NodeAdminStateUpdater { +@ThreadSafe +public interface NodeAdminStateUpdater extends NodeAdminDebugHandler { enum State { TRANSITIONING, RESUMED, SUSPENDED_NODE_ADMIN, SUSPENDED} /** @@ -12,6 +13,4 @@ public interface NodeAdminStateUpdater { * has converged. */ boolean setResumeStateAndCheckIfResumed(State wantedState); - - Map<String, Object> getDebugPage(); } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributes.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributes.java new file mode 100644 index 00000000000..3910398a040 --- /dev/null +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributes.java @@ -0,0 +1,27 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.hosted.node.admin.task.util.file; + +import java.nio.file.attribute.PosixFileAttributes; +import java.nio.file.attribute.PosixFilePermissions; +import java.time.Instant; + +/** + * This wrapper around PosixFileAttributes. + * + * @author hakonhall + */ +public class FileAttributes { + private final PosixFileAttributes attributes; + + FileAttributes(PosixFileAttributes attributes) { + this.attributes = attributes; + } + + public Instant lastModifiedTime() { return attributes.lastModifiedTime().toInstant(); } + public String owner() { return attributes.owner().getName(); } + public String group() { return attributes.group().getName(); } + public String permissions() { return PosixFilePermissions.toString(attributes.permissions()); } + public boolean isRegularFile() { return attributes.isRegularFile(); } + public boolean isDirectory() { return attributes.isDirectory(); } +} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributesCache.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributesCache.java new file mode 100644 index 00000000000..12a9609f89c --- /dev/null +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributesCache.java @@ -0,0 +1,44 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.hosted.node.admin.task.util.file; + +import java.util.Optional; + +// @ThreadUnsafe +public class FileAttributesCache { + private final UnixPath path; + + private Optional<FileAttributes> attributes = Optional.empty(); + + public FileAttributesCache(UnixPath path) { + this.path = path; + } + + public FileAttributes get() { + if (!attributes.isPresent()) { + attributes = Optional.of(path.getAttributes()); + } + + return attributes.get(); + } + + public FileAttributes forceGet() { + attributes = Optional.empty(); + return get(); + } + + public boolean exists() { + if (attributes.isPresent()) { + return true; + } + + Optional<FileAttributes> attributes = path.getAttributesIfExists(); + if (attributes.isPresent()) { + // Might as well update this.attributes + this.attributes = attributes; + return true; + } else { + return false; + } + } +} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileContentCache.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileContentCache.java new file mode 100644 index 00000000000..ca79e8bb113 --- /dev/null +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileContentCache.java @@ -0,0 +1,36 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.hosted.node.admin.task.util.file; + +import java.time.Instant; +import java.util.Optional; + +/** + * Class to avoid repeated reads of file content when the file seldom changes. + * + * @author hakonhall + */ +class FileContentCache { + private final UnixPath path; + + private Optional<String> value = Optional.empty(); + private Optional<Instant> modifiedTime = Optional.empty(); + + FileContentCache(UnixPath path) { + this.path = path; + } + + String get(Instant lastModifiedTime) { + if (!value.isPresent() || lastModifiedTime.compareTo(modifiedTime.get()) > 0) { + value = Optional.of(path.readUtf8File()); + modifiedTime = Optional.of(lastModifiedTime); + } + + return value.get(); + } + + void updateWith(String content, Instant modifiedTime) { + this.value = Optional.of(content); + this.modifiedTime = Optional.of(modifiedTime); + } +} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileSync.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileSync.java new file mode 100644 index 00000000000..d8b8aadfff7 --- /dev/null +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileSync.java @@ -0,0 +1,118 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.hosted.node.admin.task.util.file; + +import com.yahoo.vespa.hosted.node.admin.component.TaskContext; + +import java.nio.file.Path; +import java.util.Objects; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Supplier; +import java.util.logging.Logger; + +/** + * Class to minimize resource usage with repetitive and mostly identical, idempotent, and + * mutating file operations, e.g. setting file content, setting owner, etc. + * + * Only changes to the file is logged. + * + * @author hakohall + */ +// @ThreadUnsafe +public class FileSync { + private static final Logger logger = Logger.getLogger(FileSync.class.getName()); + + private final UnixPath path; + private final FileContentCache contentCache; + + public FileSync(Path path) { + this.path = new UnixPath(path); + this.contentCache = new FileContentCache(this.path); + } + + /** + * CPU, I/O, and memory usage is optimized for repeated calls with the same arguments. + * @return true if the system was modified: content was written, or owner was set, etc. + * system is only modified if necessary (different). + */ + public boolean convergeTo(TaskContext taskContext, PartialFileData partialFileData) { + FileAttributesCache currentAttributes = new FileAttributesCache(path); + + boolean modifiedSystem = false; + + modifiedSystem |= maybeUpdateContent(taskContext, partialFileData.getContent(), currentAttributes); + + modifiedSystem |= convergeAttribute( + taskContext, + "owner", + partialFileData.getOwner(), + () -> currentAttributes.get().owner(), + path::setOwner); + + modifiedSystem |= convergeAttribute( + taskContext, + "group", + partialFileData.getGroup(), + () -> currentAttributes.get().group(), + path::setGroup); + + modifiedSystem |= convergeAttribute( + taskContext, + "permissions", + partialFileData.getPermissions(), + () -> currentAttributes.get().permissions(), + path::setPermissions); + + return modifiedSystem; + } + + private boolean convergeAttribute(TaskContext taskContext, + String attributeName, + Optional<String> wantedValue, + Supplier<String> currentValueSupplier, + Consumer<String> valueSetter) { + if (!wantedValue.isPresent()) { + return false; + } + + String currentValue = currentValueSupplier.get(); + if (Objects.equals(wantedValue.get(), currentValue)) { + return false; + } else { + String actionDescription = String.format("Changing %s of %s from %s to %s", + attributeName, + path, + currentValue, + wantedValue.get()); + taskContext.logSystemModification(logger, actionDescription); + valueSetter.accept(wantedValue.get()); + return true; + } + } + + private boolean maybeUpdateContent(TaskContext taskContext, + Optional<String> content, + FileAttributesCache currentAttributes) { + if (!content.isPresent()) { + return false; + } + + if (!currentAttributes.exists()) { + taskContext.logSystemModification(logger, "Creating file " + path); + path.createParents(); + path.writeUtf8File(content.get()); + contentCache.updateWith(content.get(), currentAttributes.forceGet().lastModifiedTime()); + return true; + } + + if (Objects.equals(content.get(), contentCache.get(currentAttributes.get().lastModifiedTime()))) { + return false; + } else { + taskContext.logSystemModification(logger, "Patching file " + path); + path.writeUtf8File(content.get()); + contentCache.updateWith(content.get(), currentAttributes.forceGet().lastModifiedTime()); + return true; + } + } +} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileWriter.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileWriter.java index 60a7b3482b2..58518ae5a15 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileWriter.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileWriter.java @@ -6,56 +6,53 @@ import org.glassfish.jersey.internal.util.Producer; import java.nio.file.Files; import java.nio.file.Path; -import java.util.Optional; -import java.util.logging.Logger; +/** + * Write a file + * + * @author hakonhall + */ public class FileWriter { - private static final Logger logger = Logger.getLogger(FileWriter.class.getName()); - private final Path path; + private final FileSync fileSync; + private final PartialFileData.Builder fileDataBuilder = PartialFileData.builder(); private final Producer<String> contentProducer; - private Optional<String> owner = Optional.empty(); - private Optional<String> group = Optional.empty(); - private Optional<String> permissions = Optional.empty(); + private boolean overwriteExistingFile = true; public FileWriter(Path path, Producer<String> contentProducer) { this.path = path; + this.fileSync = new FileSync(path); this.contentProducer = contentProducer; } public FileWriter withOwner(String owner) { - this.owner = Optional.of(owner); + fileDataBuilder.withOwner(owner); return this; } public FileWriter withGroup(String group) { - this.group = Optional.of(group); + fileDataBuilder.withGroup(group); return this; } public FileWriter withPermissions(String permissions) { - this.permissions = Optional.of(permissions); + fileDataBuilder.withPermissions(permissions); + return this; + } + + public FileWriter onlyIfFileDoesNotAlreadyExist() { + overwriteExistingFile = false; return this; } public boolean converge(TaskContext context) { - // TODO: Only return false if content, permission, etc would be unchanged. - if (Files.isRegularFile(path)) { + if (!overwriteExistingFile && Files.isRegularFile(path)) { return false; } - context.logSystemModification(logger,"Writing file " + path); - - String content = contentProducer.call(); - - UnixPath unixPath = new UnixPath(path); - unixPath.createParents(); - unixPath.writeUtf8File(content); - permissions.ifPresent(unixPath::setPermissions); - owner.ifPresent(unixPath::setOwner); - group.ifPresent(unixPath::setGroup); - - return true; + fileDataBuilder.withContent(contentProducer.call()); + PartialFileData fileData = fileDataBuilder.create(); + return fileSync.convergeTo(context, fileData); } } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/IOExceptionUtil.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/IOExceptionUtil.java index dee5525d42a..9bcf601c262 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/IOExceptionUtil.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/IOExceptionUtil.java @@ -3,7 +3,12 @@ package com.yahoo.vespa.hosted.node.admin.task.util.file; import java.io.IOException; import java.io.UncheckedIOException; +import java.nio.file.NoSuchFileException; +import java.util.Optional; +/** + * @author hakonhall + */ public class IOExceptionUtil { public static <T> void uncheck(RunnableThrowingIOException<T> runnable) { try { @@ -31,4 +36,20 @@ public class IOExceptionUtil { public interface RunnableThrowingIOException<T> { void run() throws IOException; } + + /** + * Useful if it's not known whether a file or directory exists, in case e.g. + * NoSuchFileException is thrown and the caller wants an Optional.empty() in that case. + */ + public static <T> Optional<T> ifExists(SupplierThrowingIOException<T> supplier) { + try { + return Optional.ofNullable(uncheck(supplier)); + } catch (UncheckedIOException e) { + if (e.getCause() instanceof NoSuchFileException) { + return Optional.empty(); + } + + throw e; + } + } } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/PartialFileData.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/PartialFileData.java new file mode 100644 index 00000000000..b931a374230 --- /dev/null +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/PartialFileData.java @@ -0,0 +1,64 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.hosted.node.admin.task.util.file; + +import java.util.Optional; + +/** + * Represents a subset of a file's content, owner, group, and permissions. + * + * @author hakonhall + */ +// @Immutable +public class PartialFileData { + private final Optional<String> content; + private final Optional<String> owner; + private final Optional<String> group; + private final Optional<String> permissions; + + public static Builder builder() { + return new Builder(); + } + + public PartialFileData(Optional<String> content, + Optional<String> owner, + Optional<String> group, + Optional<String> permissions) { + this.content = content; + this.owner = owner; + this.group = group; + this.permissions = permissions; + } + + public Optional<String> getContent() { + return content; + } + + public Optional<String> getOwner() { + return owner; + } + + public Optional<String> getGroup() { + return group; + } + + public Optional<String> getPermissions() { + return permissions; + } + + public static class Builder { + private Optional<String> content = Optional.empty(); + private Optional<String> owner = Optional.empty(); + private Optional<String> group = Optional.empty(); + private Optional<String> permissions = Optional.empty(); + + public Builder withContent(String content) { this.content = Optional.of(content); return this; } + public Builder withOwner(String owner) { this.owner = Optional.of(owner); return this; } + public Builder withGroup(String group) { this.group = Optional.of(group); return this; } + public Builder withPermissions(String permissions) { this.permissions = Optional.of(permissions); return this; } + + public PartialFileData create() { + return new PartialFileData(content, owner, group, permissions); + } + } +} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/TemplateFile.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/TemplateFile.java new file mode 100644 index 00000000000..e4dd5cf5d9c --- /dev/null +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/TemplateFile.java @@ -0,0 +1,47 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.node.admin.task.util.file; + +import org.apache.velocity.Template; +import org.apache.velocity.VelocityContext; +import org.apache.velocity.app.Velocity; +import org.apache.velocity.app.VelocityEngine; + +import java.io.StringWriter; +import java.nio.file.Path; + +/** + * Make a file based on a Velocity template file. + * + * @author hakonhall + */ +public class TemplateFile { + private final Path templatePath; + private final VelocityEngine velocityEngine; + private final VelocityContext velocityContext = new VelocityContext(); + + public TemplateFile(Path templatePath) { + this.templatePath = templatePath; + velocityEngine = new VelocityEngine(); + velocityEngine.addProperty( + Velocity.RUNTIME_LOG_LOGSYSTEM_CLASS, + "org.apache.velocity.runtime.log.NullLogSystem"); + velocityEngine.addProperty(Velocity.FILE_RESOURCE_LOADER_PATH, templatePath.getParent().toString()); + velocityEngine.init(); + } + + public TemplateFile set(String name, String value) { + velocityContext.put(name, value); + return this; + } + + public FileWriter getFileWriterTo(Path destinationPath) { + return new FileWriter(destinationPath, this::render); + } + + private String render() { + Template template = velocityEngine.getTemplate(templatePath.getFileName().toString(), "UTF-8"); + StringWriter writer = new StringWriter(); + template.merge(velocityContext, writer); + return writer.toString(); + } +} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/UnixPath.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/UnixPath.java index 606f8cfb06e..aaffea05d1e 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/UnixPath.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/UnixPath.java @@ -1,8 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.node.admin.task.util.file; -import java.io.IOException; -import java.io.UncheckedIOException; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.OpenOption; @@ -16,8 +14,16 @@ import java.nio.file.attribute.PosixFilePermissions; import java.nio.file.attribute.UserPrincipal; import java.nio.file.attribute.UserPrincipalLookupService; import java.time.Instant; +import java.util.Optional; import java.util.Set; +import static com.yahoo.vespa.hosted.node.admin.task.util.file.IOExceptionUtil.uncheck; + +/** + * Thin wrapper around java.nio.file.Path, especially nice for UNIX-specific features. + * + * @author hakonhall + */ // @Immutable public class UnixPath { private final Path path; @@ -34,8 +40,14 @@ public class UnixPath { return path; } - public void createParents() { - uncheck(() -> Files.createDirectories(path.getParent())); + public boolean createParents() { + Path parent = path.getParent(); + if (Files.isDirectory(parent)) { + return false; + } + + uncheck(() -> Files.createDirectories(parent)); + return true; } public String readUtf8File() { @@ -49,7 +61,7 @@ public class UnixPath { } public String getPermissions() { - return PosixFilePermissions.toString(getAttributes().permissions()); + return getAttributes().permissions(); } /** @@ -69,7 +81,7 @@ public class UnixPath { } public String getOwner() { - return getAttributes().owner().getName(); + return getAttributes().owner(); } public void setOwner(String owner) { @@ -79,7 +91,7 @@ public class UnixPath { } public String getGroup() { - return getAttributes().group().getName(); + return getAttributes().group(); } public void setGroup(String group) { @@ -89,37 +101,21 @@ public class UnixPath { } public Instant getLastModifiedTime() { - return uncheck(() -> Files.getLastModifiedTime(path)).toInstant(); + return getAttributes().lastModifiedTime(); } - private PosixFileAttributes getAttributes() { - return uncheck(() -> + public FileAttributes getAttributes() { + PosixFileAttributes attributes = uncheck(() -> Files.getFileAttributeView(path, PosixFileAttributeView.class).readAttributes()); + return new FileAttributes(attributes); } - @FunctionalInterface - private interface SupplierThrowingIOException<T> { - T get() throws IOException; + public Optional<FileAttributes> getAttributesIfExists() { + return IOExceptionUtil.ifExists(() -> getAttributes()); } - private static <T> T uncheck(SupplierThrowingIOException<T> supplier) { - try { - return supplier.get(); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - @FunctionalInterface - private interface RunnableThrowingIOException<T> { - void run() throws IOException; - } - - private static <T> void uncheck(RunnableThrowingIOException<T> runnable) { - try { - runnable.run(); - } catch (IOException e) { - throw new UncheckedIOException(e); - } + @Override + public String toString() { + return path.toString(); } } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ChildProcess.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ChildProcess.java new file mode 100644 index 00000000000..00bcca71970 --- /dev/null +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ChildProcess.java @@ -0,0 +1,20 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.node.admin.task.util.process; + +import java.nio.file.Path; + +/** + * @author hakonhall + */ +public interface ChildProcess extends AutoCloseable { + ChildProcess waitForTermination(); + int exitValue(); + ChildProcess throwIfFailed(); + String getUtf8Output(); + + @Override + void close(); + + // For testing only + Path getProcessOutputPath(); +} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ChildProcessImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ChildProcessImpl.java new file mode 100644 index 00000000000..367688f0bb4 --- /dev/null +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/ChildProcessImpl.java @@ -0,0 +1,80 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.node.admin.task.util.process; + +import com.yahoo.vespa.hosted.node.admin.task.util.file.UnixPath; + +import java.nio.file.Path; + +/** + * Represents a forked child process that still exists or has terminated. + * + * @author hakonhall + */ +public class ChildProcessImpl implements ChildProcess { + private final Process process; + private final Path processOutputPath; + private final String commandLine; + + ChildProcessImpl(Process process, Path processOutputPath, String commandLine) { + this.process = process; + this.processOutputPath = processOutputPath; + this.commandLine = commandLine; + } + + public String getUtf8Output() { + return new UnixPath(processOutputPath).readUtf8File(); + } + + public ChildProcessImpl waitForTermination() { + while (true) { + try { + process.waitFor(); + } catch (InterruptedException e) { + // ignoring + continue; + } + + return this; + } + } + + public int exitValue() { + return process.exitValue(); + } + + public ChildProcess throwIfFailed() { + if (process.exitValue() != 0) { + throw new CommandException("Execution of program [" + commandLine + + "] failed, stdout/stderr was: <" + suffixOfOutputForLog() + ">"); + } + + return this; + } + + private String suffixOfOutputForLog() { + String output = getUtf8Output(); + + final int maxTrailingChars = 300; + if (output.length() <= maxTrailingChars) { + return output; + } + + int numSkippedChars = output.length() - maxTrailingChars; + output = output.substring(numSkippedChars); + return "[" + numSkippedChars + " chars omitted]..." + output; + } + + @Override + public void close() { + if (process.isAlive()) { + process.destroyForcibly(); + waitForTermination(); + } + processOutputPath.toFile().delete(); + } + + @Override + public Path getProcessOutputPath() { + return processOutputPath; + } +} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/Command.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/Command.java new file mode 100644 index 00000000000..049490f2705 --- /dev/null +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/Command.java @@ -0,0 +1,95 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.node.admin.task.util.process; + +import com.yahoo.vespa.hosted.node.admin.component.TaskContext; +import com.yahoo.vespa.hosted.node.admin.task.util.file.IOExceptionUtil; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.logging.Logger; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +/** + * Class to fork and exec a program, and gets its exit status and output. + * + * @author hakonhall + */ +public class Command { + private static Logger logger = Logger.getLogger(Command.class.getName()); + private static Pattern ARGUMENT_PATTERN = Pattern.compile("^[a-zA-Z0-9=@%/+:.,_-]+$"); + + private final TaskContext context; + private final List<String> arguments = new ArrayList<>(); + + public Command(TaskContext context) { + this.context = context; + } + + public Command add(String... arguments) { return add(Arrays.asList(arguments)); } + public Command add(List<String> arguments) { + this.arguments.addAll(arguments); + return this; + } + + public ChildProcess spawn(Logger commandLogger) { + if (arguments.isEmpty()) { + throw new IllegalStateException("No program has been specified"); + } + + String commandLine = commandLine(); + if (commandLogger != null) { + context.logSystemModification(commandLogger, "Executing command: " + commandLine); + } + + // Why isn't this using TaskContext.fileSystem? Because createTempFile assumes + // default FileSystem. And Jimfs doesn't support toFile() needed for Redirect below. + Path temporaryFile = IOExceptionUtil.uncheck(() -> Files.createTempFile( + Command.class.getSimpleName() + "-", + ".out")); + + ProcessBuilder builder = new ProcessBuilder(arguments) + .redirectError(ProcessBuilder.Redirect.appendTo(temporaryFile.toFile())) + .redirectOutput(temporaryFile.toFile()); + Process process = IOExceptionUtil.uncheck(builder::start); + + return new ChildProcessImpl(process, temporaryFile, commandLine); + } + + String commandLine() { + return arguments.stream().map(Command::maybeEscapeArgument).collect(Collectors.joining(" ")); + } + + private static String maybeEscapeArgument(String argument) { + if (ARGUMENT_PATTERN.matcher(argument).matches()) { + return argument; + } else { + return escapeArgument(argument); + } + } + + private static String escapeArgument(String argument) { + StringBuilder doubleQuoteEscaped = new StringBuilder(argument.length() + 10); + + for (int i = 0; i < argument.length(); ++i) { + char c = argument.charAt(i); + switch (c) { + case '"': + case '\\': + doubleQuoteEscaped.append("\\").append(c); + break; + default: + doubleQuoteEscaped.append(c); + } + } + + return "\"" + doubleQuoteEscaped + "\""; + } + + public ChildProcess spawnWithoutLoggingCommand() { + return spawn(null); + } +} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/CommandException.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/CommandException.java new file mode 100644 index 00000000000..148f2102ddf --- /dev/null +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/CommandException.java @@ -0,0 +1,12 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.node.admin.task.util.process; + +/** + * @author hakonhall + */ +@SuppressWarnings("serial") +public class CommandException extends RuntimeException { + public CommandException(String message) { + super(message); + } +} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/package-info.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/package-info.java new file mode 100644 index 00000000000..16da9a3b7ca --- /dev/null +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/process/package-info.java @@ -0,0 +1,5 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +package com.yahoo.vespa.hosted.node.admin.task.util.process; + +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepo.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepo.java index 9ca1c0286f9..2b1cbbed974 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepo.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepo.java @@ -7,8 +7,11 @@ import com.yahoo.vespa.hosted.node.admin.task.util.file.FileWriter; import java.nio.file.Path; import java.util.regex.Pattern; +/** + * @author hakonhall + */ public class AddYumRepo { - private static final Pattern REPOSITORY_ID_PATTERN = Pattern.compile("^[a-zA-Z_-]+$"); + private static final Pattern REPOSITORY_ID_PATTERN = Pattern.compile("^[a-zA-Z0-9_-]+$"); private final String repositoryId; // e.g. "platform_rpms-latest" private final String name; // e.g. "Platform RPM Latest Repo" @@ -32,7 +35,8 @@ public class AddYumRepo { FileWriter fileWriter = new FileWriter(path, this::getRepoFileContent) .withOwner("root") .withGroup("root") - .withPermissions("rw-r--r--"); + .withPermissions("rw-r--r--") + .onlyIfFileDoesNotAlreadyExist(); return fileWriter.converge(context); } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/Yum.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/Yum.java new file mode 100644 index 00000000000..c1514f1056b --- /dev/null +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/Yum.java @@ -0,0 +1,86 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.node.admin.task.util.yum; + +import com.yahoo.vespa.hosted.node.admin.component.TaskContext; +import com.yahoo.vespa.hosted.node.admin.task.util.process.ChildProcess; +import com.yahoo.vespa.hosted.node.admin.task.util.process.Command; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; +import java.util.logging.Logger; + +/** + * @author hakonhall + */ +public class Yum { + private static Logger logger = Logger.getLogger(Yum.class.getName()); + + private final TaskContext taskContext; + private final Supplier<Command> commandSupplier; + private List<String> packages = new ArrayList<>(); + + public Yum(TaskContext taskContext) { + this.taskContext = taskContext; + this.commandSupplier = () -> new Command(taskContext); + } + + /** + * @param packages A list of packages, each package being of the form name-1.2.3-1.el7.noarch + */ + public Install install(String... packages) { + return new Install(taskContext, Arrays.asList(packages)); + } + + public class Install { + private final TaskContext taskContext; + private final List<String> packages; + private Optional<String> enabledRepo = Optional.empty(); + + public Install(TaskContext taskContext, List<String> packages) { + this.taskContext = taskContext; + this.packages = packages; + + if (packages.isEmpty()) { + throw new IllegalArgumentException("No packages specified"); + } + } + + public Install enableRepo(String repo) { + enabledRepo = Optional.of(repo); + return this; + } + + public boolean converge() { + if (packages.stream().allMatch(Yum.this::isInstalled)) { + return false; + } + + execute(); + return true; + } + + private void execute() { + Command command = commandSupplier.get(); + command.add("yum", "install", "--assumeyes"); + enabledRepo.ifPresent(repo -> command.add("--enablerepo=" + repo)); + command.add(packages); + command.spawn(logger).waitForTermination().throwIfFailed(); + } + } + + Yum(TaskContext taskContext, Supplier<Command> commandSupplier) { + this.taskContext = taskContext; + this.commandSupplier = commandSupplier; + } + + private boolean isInstalled(String package_) { + ChildProcess childProcess = commandSupplier.get() + .add("yum", "list", "installed", package_) + .spawnWithoutLoggingCommand(); + childProcess.waitForTermination(); + return childProcess.exitValue() == 0; + } +} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/KeyStoreOptions.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/KeyStoreOptions.java index fbcaf701c6f..84db5840909 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/KeyStoreOptions.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/KeyStoreOptions.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.node.admin.util; import java.io.FileInputStream; diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/SelfCloseableHttpClient.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/SelfCloseableHttpClient.java index ddb473d348c..8e516729aff 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/SelfCloseableHttpClient.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/SelfCloseableHttpClient.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.node.admin.util; import com.yahoo.log.LogLevel; diff --git a/node-admin/src/main/resources/configdefinitions/config-server.def b/node-admin/src/main/resources/configdefinitions/config-server.def index e4265e618a1..0de79160277 100644 --- a/node-admin/src/main/resources/configdefinitions/config-server.def +++ b/node-admin/src/main/resources/configdefinitions/config-server.def @@ -1,4 +1,4 @@ -# Copyright 2018 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. namespace=vespa.hosted.node.admin.config hosts[] string diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/provider/DebugHandlerHelperTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/provider/DebugHandlerHelperTest.java new file mode 100644 index 00000000000..723b9f0df8a --- /dev/null +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/provider/DebugHandlerHelperTest.java @@ -0,0 +1,35 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.hosted.node.admin.provider; + +import org.junit.Test; + +import java.util.Collections; +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +public class DebugHandlerHelperTest { + @Test + public void trivial() { + DebugHandlerHelper helper = new DebugHandlerHelper(); + helper.addConstant("constant-key", "constant-value"); + + NodeAdminDebugHandler handler = new NodeAdminDebugHandler() { + @Override + public Map<String, Object> getDebugPage() { + return Collections.singletonMap("handler-value-key", "handler-value-value"); + } + }; + helper.addHandler("handler-key", handler); + + helper.addThreadSafeSupplier("supplier-key", () -> "supplier-value"); + + assertEquals("{" + + "supplier-key=supplier-value, " + + "handler-key={handler-value-key=handler-value-value}, " + + "constant-key=constant-value" + + "}", + helper.getDebugPage().toString()); + } +}
\ No newline at end of file diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributesCacheTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributesCacheTest.java new file mode 100644 index 00000000000..9224faf1c6f --- /dev/null +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributesCacheTest.java @@ -0,0 +1,38 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.hosted.node.admin.task.util.file; + +import org.junit.Test; + +import java.util.Optional; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class FileAttributesCacheTest { + @Test + public void exists() throws Exception { + UnixPath unixPath = mock(UnixPath.class); + FileAttributesCache cache = new FileAttributesCache(unixPath); + + when(unixPath.getAttributesIfExists()).thenReturn(Optional.empty()); + assertFalse(cache.exists()); + verify(unixPath, times(1)).getAttributesIfExists(); + verifyNoMoreInteractions(unixPath); + + FileAttributes attributes = mock(FileAttributes.class); + when(unixPath.getAttributesIfExists()).thenReturn(Optional.of(attributes)); + assertTrue(cache.exists()); + verify(unixPath, times(1 + 1)).getAttributesIfExists(); + verifyNoMoreInteractions(unixPath); + + assertEquals(attributes, cache.get()); + verifyNoMoreInteractions(unixPath); + } +}
\ No newline at end of file diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileContentCacheTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileContentCacheTest.java new file mode 100644 index 00000000000..677dd048445 --- /dev/null +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileContentCacheTest.java @@ -0,0 +1,58 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.hosted.node.admin.task.util.file; + +import org.junit.Test; + +import java.time.Instant; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class FileContentCacheTest { + private final UnixPath unixPath = mock(UnixPath.class); + private final FileContentCache cache = new FileContentCache(unixPath); + + @Test + public void get() throws Exception { + when(unixPath.readUtf8File()).thenReturn("content"); + assertEquals("content", cache.get(Instant.ofEpochMilli(0))); + verify(unixPath, times(1)).readUtf8File(); + verifyNoMoreInteractions(unixPath); + + // cache hit + assertEquals("content", cache.get(Instant.ofEpochMilli(0))); + verify(unixPath, times(1)).readUtf8File(); + verifyNoMoreInteractions(unixPath); + + // cache miss + when(unixPath.readUtf8File()).thenReturn("new-content"); + assertEquals("new-content", cache.get(Instant.ofEpochMilli(1))); + verify(unixPath, times(1 + 1)).readUtf8File(); + verifyNoMoreInteractions(unixPath); + + // cache hit both at times 0 and 1 + assertEquals("new-content", cache.get(Instant.ofEpochMilli(0))); + verify(unixPath, times(1 + 1)).readUtf8File(); + verifyNoMoreInteractions(unixPath); + assertEquals("new-content", cache.get(Instant.ofEpochMilli(1))); + verify(unixPath, times(1 + 1)).readUtf8File(); + verifyNoMoreInteractions(unixPath); + } + + @Test + public void updateWith() throws Exception { + cache.updateWith("content", Instant.ofEpochMilli(2)); + assertEquals("content", cache.get(Instant.ofEpochMilli(2))); + verifyNoMoreInteractions(unixPath); + + cache.updateWith("new-content", Instant.ofEpochMilli(4)); + assertEquals("new-content", cache.get(Instant.ofEpochMilli(4))); + verifyNoMoreInteractions(unixPath); + } + +}
\ No newline at end of file diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileSyncTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileSyncTest.java new file mode 100644 index 00000000000..44868e17464 --- /dev/null +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileSyncTest.java @@ -0,0 +1,84 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.hosted.node.admin.task.util.file; + +import com.yahoo.vespa.test.file.TestFileSystem; +import org.junit.Test; + +import java.nio.file.FileSystem; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class FileSyncTest { + private final TestTaskContext taskContext = new TestTaskContext(); + private final FileSystem fileSystem = TestFileSystem.create(); + + private final Path path = fileSystem.getPath("/dir/file.txt"); + private final UnixPath unixPath = new UnixPath(path); + private final FileSync fileSync = new FileSync(path); + + private String content = "content"; + private String owner = "owner"; // default is user + private String group = "group1"; // default is group + private String permissions = "rw-r-xr--"; + + @Test + public void trivial() { + assertConvergence("Creating file /dir/file.txt", + "Changing owner of /dir/file.txt from user to owner", + "Changing group of /dir/file.txt from group to group1", + "Changing permissions of /dir/file.txt from rw-r--r-- to rw-r-xr--"); + + content = "new-content"; + assertConvergence("Patching file /dir/file.txt"); + + owner = "new-owner"; + assertConvergence("Changing owner of /dir/file.txt from owner to " + + owner); + + group = "new-group1"; + assertConvergence("Changing group of /dir/file.txt from group1 to new-group1"); + + permissions = "rwxr--rwx"; + assertConvergence("Changing permissions of /dir/file.txt from rw-r-xr-- to " + + permissions); + } + + private void assertConvergence(String... systemModificationMessages) { + PartialFileData fileData = PartialFileData.builder() + .withContent(content) + .withOwner(owner) + .withGroup(group) + .withPermissions(permissions) + .create(); + taskContext.clearSystemModificationLog(); + assertTrue(fileSync.convergeTo(taskContext, fileData)); + + assertTrue(Files.isRegularFile(path)); + fileData.getContent().ifPresent(content -> assertEquals(content, unixPath.readUtf8File())); + fileData.getOwner().ifPresent(owner -> assertEquals(owner, unixPath.getOwner())); + fileData.getGroup().ifPresent(group -> assertEquals(group, unixPath.getGroup())); + fileData.getPermissions().ifPresent(permissions -> assertEquals(permissions, unixPath.getPermissions())); + + List<String> actualMods = taskContext.getSystemModificationLog(); + List<String> expectedMods = Arrays.asList(systemModificationMessages); + assertEquals(expectedMods, actualMods); + + UnixPath unixPath = new UnixPath(path); + Instant lastModifiedTime = unixPath.getLastModifiedTime(); + taskContext.clearSystemModificationLog(); + assertFalse(fileSync.convergeTo(taskContext, fileData)); + assertEquals(lastModifiedTime, unixPath.getLastModifiedTime()); + + actualMods = taskContext.getSystemModificationLog(); + assertEquals(new ArrayList<>(), actualMods); + } +} diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileWriterTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileWriterTest.java index ca4eabf855b..bb8ca2586c8 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileWriterTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileWriterTest.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.hosted.node.admin.task.util.file; +import com.yahoo.vespa.test.file.TestFileSystem; import com.yahoo.vespa.hosted.node.admin.component.TaskContext; import org.junit.Test; @@ -32,10 +33,11 @@ public class FileWriterTest { FileWriter writer = new FileWriter(path, () -> content) .withPermissions(permissions) .withOwner(owner) - .withGroup(group); + .withGroup(group) + .onlyIfFileDoesNotAlreadyExist(); TaskContext context = mock(TaskContext.class); assertTrue(writer.converge(context)); - verify(context, times(1)).logSystemModification(any(), eq("Writing file " + path)); + verify(context, times(1)).logSystemModification(any(), eq("Creating file " + path)); UnixPath unixPath = new UnixPath(path); assertEquals(content, unixPath.readUtf8File()); diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/TemplateFileTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/TemplateFileTest.java new file mode 100644 index 00000000000..b1d88fdaaee --- /dev/null +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/TemplateFileTest.java @@ -0,0 +1,49 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.hosted.node.admin.task.util.file; + +import com.yahoo.vespa.hosted.node.admin.component.TaskContext; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.IOException; +import java.nio.file.Path; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +/** + * WARNING: Velocity does not honor an alternative FileSystem like JimFS. + */ +public class TemplateFileTest { + @Rule + public TemporaryFolder folder = new TemporaryFolder(); + + private void writeFile(Path path, String content) { + UnixPath unixPath = new UnixPath(path); + unixPath.createParents(); + unixPath.writeUtf8File(content); + } + + @Test + public void basic() throws IOException { + String templateContent = "a $x, $y b"; + Path templatePath = folder.newFile("example.vm").toPath(); + writeFile(templatePath, templateContent); + + Path toPath = folder.newFile().toPath(); + TaskContext taskContext = mock(TaskContext.class); + boolean converged = new TemplateFile(templatePath) + .set("x", "foo") + .set("y", "bar") + .getFileWriterTo(toPath) + .converge(taskContext); + + assertTrue(converged); + + String actualContent = new UnixPath(toPath).readUtf8File(); + assertEquals("a foo, bar b", actualContent); + } +}
\ No newline at end of file diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/TestTaskContext.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/TestTaskContext.java new file mode 100644 index 00000000000..757f3004683 --- /dev/null +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/TestTaskContext.java @@ -0,0 +1,49 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.hosted.node.admin.task.util.file; + +import com.yahoo.vespa.hosted.node.admin.component.IdempotentTask; +import com.yahoo.vespa.hosted.node.admin.component.TaskContext; + +import java.nio.file.FileSystem; +import java.util.ArrayList; +import java.util.EnumSet; +import java.util.List; +import java.util.logging.Logger; + +public class TestTaskContext implements TaskContext { + private final List<String> logs = new ArrayList<>(); + + @Override + public Cloud cloud() { + throw new UnsupportedOperationException(); + } + + @Override + public EnumSet<Role> roles() { + throw new UnsupportedOperationException(); + } + + @Override + public FileSystem fileSystem() { + throw new UnsupportedOperationException(); + } + + @Override + public void logSystemModification(Logger logger, String actionDescription) { + logs.add(actionDescription); + } + + public List<String> getSystemModificationLog() { + return logs; + } + + public void clearSystemModificationLog() { + logs.clear(); + } + + @Override + public boolean executeSubtask(IdempotentTask task) { + throw new UnsupportedOperationException(); + } +} diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/UnixPathTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/UnixPathTest.java index 821c6397ee7..bd29f239e1d 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/UnixPathTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/UnixPathTest.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.hosted.node.admin.task.util.file; +import com.yahoo.vespa.test.file.TestFileSystem; import org.junit.Test; import java.nio.file.FileSystem; diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/process/CommandTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/process/CommandTest.java new file mode 100644 index 00000000000..373c75eba59 --- /dev/null +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/process/CommandTest.java @@ -0,0 +1,87 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.node.admin.task.util.process; + +import com.yahoo.vespa.hosted.node.admin.component.TaskContext; +import org.junit.Test; + +import java.io.UncheckedIOException; +import java.nio.file.Path; +import java.util.logging.Logger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +public class CommandTest { + @Test + public void testCommand() { + TaskContext taskContext = mock(TaskContext.class); + Logger logger = mock(Logger.class); + + Command command = new Command(taskContext).add("bash", "-c", "ls /bin/bash"); + Path outputFile; + // Assumes /bin/bash exists on all hosts running this test. + try (ChildProcess childProcess = command.spawn(logger)) { + verify(taskContext).logSystemModification(eq(logger), any()); + + outputFile = childProcess.getProcessOutputPath(); + int exitValue = childProcess.waitForTermination().exitValue(); + assertEquals(0, exitValue); + childProcess.throwIfFailed(); + String output = childProcess.getUtf8Output().trim(); + assertEquals("/bin/bash", output); + assertTrue(outputFile.toFile().exists()); + } + + assertFalse(outputFile.toFile().exists()); + } + + @Test(expected = UncheckedIOException.class) + public void noSuchProgram() { + TaskContext taskContext = mock(TaskContext.class); + Logger logger = mock(Logger.class); + + Command command = new Command(taskContext).add("thisprogRamDoes-not-exist"); + try (ChildProcess childProcess = command.spawn(logger)) { + dummyToRemoveWarning(childProcess); + } + + fail(); + } + + private void dummyToRemoveWarning(ChildProcess childProcess) { } + + @Test + public void argumentEscape() { + TaskContext taskContext = mock(TaskContext.class); + Command command = new Command(taskContext).add("b", "\" \\ foo", "bar x", ""); + assertEquals("b \"\\\" \\\\ foo\" \"bar x\" \"\"", command.commandLine()); + } + + @Test + public void failingProgram() { + TaskContext taskContext = mock(TaskContext.class); + Logger logger = mock(Logger.class); + + Command command = new Command(taskContext) + .add("bash", "-c", "echo foo; echo bar >&2; exit 1"); + Path outputFile; + try (ChildProcess childProcess = command.spawn(logger)) { + try { + childProcess.waitForTermination().throwIfFailed(); + fail(); + } catch (CommandException e) { + assertEquals("Execution of program [bash -c \"echo foo; echo bar >&2; exit 1\"] failed, stdout/stderr was: <foo\n" + + "bar\n" + + ">", + e.getMessage()); + } + } + + } +}
\ No newline at end of file diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/process/TestCommand.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/process/TestCommand.java new file mode 100644 index 00000000000..59c853f949d --- /dev/null +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/process/TestCommand.java @@ -0,0 +1,71 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.node.admin.task.util.process; + +import com.yahoo.vespa.hosted.node.admin.component.TaskContext; + +import java.nio.file.Path; +import java.util.logging.Logger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +public class TestCommand extends Command { + private final String expectedCommandLine; + private final ChildProcess childProcess; + + private boolean invoked = false; + + public TestCommand(TaskContext context, + String expectedCommandLine, + int exitValue, + String out) { + super(context); + this.expectedCommandLine = expectedCommandLine; + this.childProcess = new ChildProcess() { + @Override + public ChildProcess waitForTermination() { + return this; + } + + @Override + public int exitValue() { + return exitValue; + } + + @Override + public ChildProcess throwIfFailed() { + if (exitValue != 0) { + throw new CommandException("exited with " + exitValue); + } + return this; + } + + @Override + public String getUtf8Output() { + return out; + } + + @Override + public void close() { } + + @Override + public Path getProcessOutputPath() { return null; } + }; + } + + @Override + public ChildProcess spawn(Logger commandLogger) { + assertFalse(invoked); + invoked = true; + + assertEquals(expectedCommandLine, commandLine()); + + return childProcess; + } + + public void verifyInvocation() { + if (!invoked) { + throw new IllegalStateException("Command not invoked: " + expectedCommandLine); + } + } +} diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/process/TestCommandSupplier.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/process/TestCommandSupplier.java new file mode 100644 index 00000000000..1c900604260 --- /dev/null +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/process/TestCommandSupplier.java @@ -0,0 +1,43 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.node.admin.task.util.process; + +import com.yahoo.vespa.hosted.node.admin.component.TaskContext; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +public class TestCommandSupplier implements Supplier<Command> { + private final TaskContext taskContext; + private final List<TestCommand> expectedInvocations = new ArrayList<>(); + private int index = 0; + + public TestCommandSupplier(TaskContext taskContext) { + this.taskContext = taskContext; + } + + public TestCommandSupplier expectCommand(String commandLine, int exitValue, String out) { + expectedInvocations.add(new TestCommand(taskContext, commandLine, exitValue, out)); + return this; + } + + @Override + public Command get() { + if (index >= expectedInvocations.size()) { + throw new IllegalStateException("Too many command invocations"); + } + + return expectedInvocations.get(index++); + } + + public void verifyInvocations() { + if (index != expectedInvocations.size()) { + throw new IllegalStateException("Received only " + index + + " command invocations: expected " + expectedInvocations.size()); + } + + for (int i = 0; i < index; ++i) { + expectedInvocations.get(i).verifyInvocation(); + } + } +} diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepoTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepoTest.java index 7b6ab91345b..ad1fefe782f 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepoTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/AddYumRepoTest.java @@ -3,7 +3,7 @@ package com.yahoo.vespa.hosted.node.admin.task.util.yum; import com.yahoo.vespa.hosted.node.admin.component.TaskContext; -import com.yahoo.vespa.hosted.node.admin.task.util.file.TestFileSystem; +import com.yahoo.vespa.test.file.TestFileSystem; import com.yahoo.vespa.hosted.node.admin.task.util.file.UnixPath; import org.junit.Test; diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumTest.java new file mode 100644 index 00000000000..d852be26229 --- /dev/null +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/yum/YumTest.java @@ -0,0 +1,67 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.node.admin.task.util.yum; + +import com.yahoo.vespa.hosted.node.admin.component.TaskContext; +import com.yahoo.vespa.hosted.node.admin.task.util.process.CommandException; +import com.yahoo.vespa.hosted.node.admin.task.util.process.TestCommandSupplier; +import org.junit.Test; + +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + +public class YumTest { + @Test + public void testAlreadyInstalled() { + TaskContext taskContext = mock(TaskContext.class); + TestCommandSupplier commandSupplier = new TestCommandSupplier(taskContext); + + commandSupplier.expectCommand("yum list installed package-1", 0, ""); + commandSupplier.expectCommand("yum list installed package-2", 0, ""); + + Yum yum = new Yum(taskContext, commandSupplier); + yum.install("package-1", "package-2") + .enableRepo("repo-name") + .converge(); + + commandSupplier.verifyInvocations(); + } + + @Test + public void testInstall() { + TaskContext taskContext = mock(TaskContext.class); + TestCommandSupplier commandSupplier = new TestCommandSupplier(taskContext); + + commandSupplier.expectCommand("yum list installed package-1", 0, ""); + commandSupplier.expectCommand("yum list installed package-2", 1, ""); + commandSupplier.expectCommand( + "yum install --assumeyes --enablerepo=repo-name package-1 package-2", + 0, + ""); + + Yum yum = new Yum(taskContext, commandSupplier); + yum.install("package-1", "package-2") + .enableRepo("repo-name") + .converge(); + + commandSupplier.verifyInvocations(); + } + + @Test(expected = CommandException.class) + public void testFailedInstall() { + TaskContext taskContext = mock(TaskContext.class); + TestCommandSupplier commandSupplier = new TestCommandSupplier(taskContext); + + commandSupplier.expectCommand("yum list installed package-1", 0, ""); + commandSupplier.expectCommand("yum list installed package-2", 1, ""); + commandSupplier.expectCommand( + "yum install --assumeyes --enablerepo=repo-name package-1 package-2", + 1, + "error"); + + Yum yum = new Yum(taskContext, commandSupplier); + yum.install("package-1", "package-2") + .enableRepo("repo-name") + .converge(); + fail(); + } +}
\ No newline at end of file diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java index 19894050ff4..c26e59a0b1a 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java @@ -391,9 +391,8 @@ public class NodeRepository extends AbstractComponent { Node nodeToDirty = getNode(hostname, Node.State.provisioned, Node.State.failed, Node.State.parked).orElseThrow(() -> new IllegalArgumentException("Could not deallocate " + hostname + ": No such node in the provisioned, failed or parked state")); - if (nodeToDirty.status().hardwareFailureDescription().isPresent() || nodeToDirty.status().hardwareDivergence().isPresent()) - throw new IllegalArgumentException("Could not deallocate " + hostname + ": It has a hardware failure/spec divergence"); - + if (nodeToDirty.status().hardwareFailureDescription().isPresent()) + throw new IllegalArgumentException("Could not deallocate " + hostname + ": It has a hardware failure"); return setDirty(nodeToDirty); } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java index c23a7f9990a..bdea767eb0d 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/RestApiTest.java @@ -282,7 +282,7 @@ public class RestApiTest { "{\"message\":\"Moved host12.yahoo.com to failed\"}"); assertResponse(new Request("http://localhost:8080/nodes/v2/state/dirty/host12.yahoo.com", new byte[0], Request.Method.PUT), 400, - "{\"error-code\":\"BAD_REQUEST\",\"message\":\"Could not deallocate host12.yahoo.com: It has a hardware failure/spec divergence\"}"); + "{\"error-code\":\"BAD_REQUEST\",\"message\":\"Could not deallocate host12.yahoo.com: It has a hardware failure\"}"); } @Test diff --git a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ApplicationApi.java b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ApplicationApi.java index 7e044379137..0ca509d13f1 100644 --- a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ApplicationApi.java +++ b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ApplicationApi.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.orchestrator.model; +import com.yahoo.config.provision.ApplicationId; import com.yahoo.vespa.applicationmodel.HostName; import com.yahoo.vespa.orchestrator.status.ApplicationInstanceStatus; import com.yahoo.vespa.orchestrator.status.HostStatus; @@ -11,7 +12,10 @@ import java.util.List; * The API a Policy has access to */ public interface ApplicationApi { - String applicationInfo(); + /** + * @return The 3-part application ID of the form tenant:name:instance. + */ + ApplicationId applicationId(); /** * The policy acts on some subset of nodes in the application. diff --git a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ApplicationApiImpl.java b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ApplicationApiImpl.java index e280341d02c..c5bcaf4de82 100644 --- a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ApplicationApiImpl.java +++ b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ApplicationApiImpl.java @@ -1,10 +1,12 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.orchestrator.model; +import com.yahoo.config.provision.ApplicationId; import com.yahoo.vespa.applicationmodel.ApplicationInstance; import com.yahoo.vespa.applicationmodel.HostName; import com.yahoo.vespa.applicationmodel.ServiceCluster; import com.yahoo.vespa.applicationmodel.ServiceInstance; +import com.yahoo.vespa.orchestrator.OrchestratorUtil; import com.yahoo.vespa.orchestrator.controller.ClusterControllerClientFactory; import com.yahoo.vespa.orchestrator.status.ApplicationInstanceStatus; import com.yahoo.vespa.orchestrator.status.HostStatus; @@ -44,8 +46,8 @@ public class ApplicationApiImpl implements ApplicationApi { } @Override - public String applicationInfo() { - return applicationInstance.reference().toString(); + public ApplicationId applicationId() { + return OrchestratorUtil.toApplicationId(applicationInstance.reference()); } private static Map<HostName, HostStatus> createHostStatusMap(Collection<HostName> hosts, @@ -113,13 +115,14 @@ public class ApplicationApiImpl implements ApplicationApi { .collect(Collectors.toList()); } - private static List<ClusterApi> makeClustersInOrder + private List<ClusterApi> makeClustersInOrder (NodeGroup nodeGroup, Map<HostName, HostStatus> hostStatusMap, ClusterControllerClientFactory clusterControllerClientFactory) { Set<ServiceCluster> clustersInGroup = getServiceClustersInGroup(nodeGroup); return clustersInGroup.stream() .map(serviceCluster -> new ClusterApiImpl( + this, serviceCluster, nodeGroup, hostStatusMap, diff --git a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ClusterApi.java b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ClusterApi.java index 1b536cce3b5..025d21316ef 100644 --- a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ClusterApi.java +++ b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ClusterApi.java @@ -7,6 +7,8 @@ import com.yahoo.vespa.applicationmodel.ServiceType; import java.util.Optional; public interface ClusterApi { + ApplicationApi getApplication(); + NodeGroup getNodeGroup(); String clusterInfo(); diff --git a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ClusterApiImpl.java b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ClusterApiImpl.java index dfa2610a130..d0217710bdb 100644 --- a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ClusterApiImpl.java +++ b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/ClusterApiImpl.java @@ -5,10 +5,10 @@ import com.yahoo.vespa.applicationmodel.ClusterId; import com.yahoo.vespa.applicationmodel.HostName; import com.yahoo.vespa.applicationmodel.ServiceCluster; import com.yahoo.vespa.applicationmodel.ServiceInstance; +import com.yahoo.vespa.applicationmodel.ServiceStatus; import com.yahoo.vespa.applicationmodel.ServiceType; import com.yahoo.vespa.orchestrator.controller.ClusterControllerClientFactory; import com.yahoo.vespa.orchestrator.status.HostStatus; -import com.yahoo.vespa.applicationmodel.ServiceStatus; import java.util.Collections; import java.util.HashSet; @@ -19,6 +19,7 @@ import java.util.function.Predicate; import java.util.stream.Collectors; class ClusterApiImpl implements ClusterApi { + private final ApplicationApi applicationApi; private final ServiceCluster serviceCluster; private final NodeGroup nodeGroup; private final Map<HostName, HostStatus> hostStatusMap; @@ -28,10 +29,12 @@ class ClusterApiImpl implements ClusterApi { private final Set<ServiceInstance> servicesNotInGroup; private final Set<ServiceInstance> servicesDownAndNotInGroup; - public ClusterApiImpl(ServiceCluster serviceCluster, + public ClusterApiImpl(ApplicationApi applicationApi, + ServiceCluster serviceCluster, NodeGroup nodeGroup, Map<HostName, HostStatus> hostStatusMap, ClusterControllerClientFactory clusterControllerClientFactory) { + this.applicationApi = applicationApi; this.serviceCluster = serviceCluster; this.nodeGroup = nodeGroup; this.hostStatusMap = hostStatusMap; @@ -71,6 +74,11 @@ class ClusterApiImpl implements ClusterApi { } @Override + public ApplicationApi getApplication() { + return applicationApi; + } + + @Override public boolean noServicesInGroupIsUp() { return servicesDownInGroup.size() == servicesInGroup.size(); } diff --git a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/VespaModelUtil.java b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/VespaModelUtil.java index 40556eb2f40..4dd8133d2e9 100644 --- a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/VespaModelUtil.java +++ b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/model/VespaModelUtil.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.orchestrator.model; +import com.yahoo.config.provision.ApplicationId; import com.yahoo.vespa.applicationmodel.ApplicationInstance; import com.yahoo.vespa.applicationmodel.ClusterId; import com.yahoo.vespa.applicationmodel.ConfigId; @@ -29,10 +30,13 @@ import static com.yahoo.collections.CollectionUtil.first; * @author hakonhall */ public class VespaModelUtil { - private static final Logger log = Logger.getLogger(VespaModelUtil.class.getName()); + public static final ApplicationId ZONE_APPLICATION_ID = + ApplicationId.from("hosted-vespa", "routing", "default"); + public static final ClusterId ADMIN_CLUSTER_ID = new ClusterId("admin"); + public static final ClusterId NODE_ADMIN_CLUSTER_ID = new ClusterId("node-admin"); public static final ServiceType SLOBROK_SERVICE_TYPE = new ServiceType("slobrok"); public static final ServiceType CLUSTER_CONTROLLER_SERVICE_TYPE = new ServiceType("container-clustercontroller"); diff --git a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/policy/ConcurrentSuspensionLimitForCluster.java b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/policy/ConcurrentSuspensionLimitForCluster.java index 85750a9e9c8..215d9ed1b2d 100644 --- a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/policy/ConcurrentSuspensionLimitForCluster.java +++ b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/policy/ConcurrentSuspensionLimitForCluster.java @@ -7,6 +7,7 @@ package com.yahoo.vespa.orchestrator.policy; public enum ConcurrentSuspensionLimitForCluster { ONE_NODE(0), TEN_PERCENT(10), + TWENTY_PERCENT(20), ALL_NODES(100); int percentage; diff --git a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/policy/HostedVespaClusterPolicy.java b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/policy/HostedVespaClusterPolicy.java index 4f718af27b1..f45aa8c02ac 100644 --- a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/policy/HostedVespaClusterPolicy.java +++ b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/policy/HostedVespaClusterPolicy.java @@ -1,8 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.orchestrator.policy; -import com.yahoo.vespa.orchestrator.model.VespaModelUtil; import com.yahoo.vespa.orchestrator.model.ClusterApi; +import com.yahoo.vespa.orchestrator.model.VespaModelUtil; import static com.yahoo.vespa.orchestrator.policy.HostedVespaPolicy.ENOUGH_SERVICES_UP_CONSTRAINT; @@ -64,6 +64,10 @@ public class HostedVespaClusterPolicy implements ClusterPolicy { // Non-private for testing purposes ConcurrentSuspensionLimitForCluster getConcurrentSuspensionLimit(ClusterApi clusterApi) { + if (clusterApi.isStorageCluster()) { + return ConcurrentSuspensionLimitForCluster.ONE_NODE; + } + if (VespaModelUtil.ADMIN_CLUSTER_ID.equals(clusterApi.clusterId())) { if (VespaModelUtil.SLOBROK_SERVICE_TYPE.equals(clusterApi.serviceType())) { return ConcurrentSuspensionLimitForCluster.ONE_NODE; @@ -72,8 +76,9 @@ public class HostedVespaClusterPolicy implements ClusterPolicy { return ConcurrentSuspensionLimitForCluster.ALL_NODES; } - if (clusterApi.isStorageCluster()) { - return ConcurrentSuspensionLimitForCluster.ONE_NODE; + if (clusterApi.getApplication().applicationId().equals(VespaModelUtil.ZONE_APPLICATION_ID) && + clusterApi.clusterId().equals(VespaModelUtil.NODE_ADMIN_CLUSTER_ID)) { + return ConcurrentSuspensionLimitForCluster.TWENTY_PERCENT; } return ConcurrentSuspensionLimitForCluster.TEN_PERCENT; diff --git a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/policy/HostedVespaPolicy.java b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/policy/HostedVespaPolicy.java index 49a7739c839..1e9efa2e700 100644 --- a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/policy/HostedVespaPolicy.java +++ b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/policy/HostedVespaPolicy.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.orchestrator.policy; +import com.yahoo.config.provision.ApplicationId; import com.yahoo.log.LogLevel; import com.yahoo.vespa.applicationmodel.ApplicationInstance; import com.yahoo.vespa.applicationmodel.HostName; @@ -22,7 +23,6 @@ import java.util.logging.Logger; */ public class HostedVespaPolicy implements Policy { - public static final String APPLICATION_SUSPENDED_CONSTRAINT = "application-suspended"; public static final String ENOUGH_SERVICES_UP_CONSTRAINT = "enough-services-up"; public static final String SET_NODE_STATE_CONSTRAINT = "controller-set-node-state"; @@ -81,8 +81,8 @@ public class HostedVespaPolicy implements Policy { throw new HostStateChangeDeniedException( applicationApi.getNodeGroup(), HostedVespaPolicy.APPLICATION_SUSPENDED_CONSTRAINT, - "Unable to test availability constraints as the application " + applicationApi.applicationInfo() - + " is allowed to be down"); + "Unable to test availability constraints as the application " + + applicationApi.applicationId() + " is allowed to be down"); } // Apply per-cluster policy diff --git a/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/model/ApplicationApiImplTest.java b/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/model/ApplicationApiImplTest.java index 3abffac3a9c..a712c1db3e8 100644 --- a/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/model/ApplicationApiImplTest.java +++ b/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/model/ApplicationApiImplTest.java @@ -19,10 +19,10 @@ public class ApplicationApiImplTest { final ModelTestUtils modelUtils = new ModelTestUtils(); @Test - public void testApplicationInfo() { + public void testApplicationId() { ApplicationApiImpl applicationApi = modelUtils.createApplicationApiImpl(modelUtils.createApplicationInstance(new ArrayList<>())); - assertEquals("tenant:application-name:foo:bar:default", applicationApi.applicationInfo()); + assertEquals("tenant:application-name:default", applicationApi.applicationId().serializedForm()); } @Test diff --git a/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/model/ClusterApiImplTest.java b/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/model/ClusterApiImplTest.java index d17ffd4452d..ad1ce647a7c 100644 --- a/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/model/ClusterApiImplTest.java +++ b/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/model/ClusterApiImplTest.java @@ -16,8 +16,10 @@ import java.util.Optional; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; public class ClusterApiImplTest { + final ApplicationApi applicationApi = mock(ApplicationApi.class); final ModelTestUtils modelUtils = new ModelTestUtils(); @Test @@ -42,6 +44,7 @@ public class ClusterApiImplTest { ); ClusterApiImpl clusterApi = new ClusterApiImpl( + applicationApi, serviceCluster, new NodeGroup(modelUtils.createApplicationInstance(new ArrayList<>()), hostName5), modelUtils.getHostStatusMap(), @@ -97,6 +100,7 @@ public class ClusterApiImplTest { boolean expectedNoServicesOutsideGroupIsDown, HostName... groupNodes) { ClusterApiImpl clusterApi = new ClusterApiImpl( + applicationApi, serviceCluster, new NodeGroup(modelUtils.createApplicationInstance(new ArrayList<>()), groupNodes), modelUtils.getHostStatusMap(), @@ -122,6 +126,7 @@ public class ClusterApiImplTest { ); ClusterApiImpl clusterApi = new ClusterApiImpl( + applicationApi, serviceCluster, new NodeGroup(modelUtils.createApplicationInstance(new ArrayList<>()), hostName1, hostName3), new HashMap<>(), diff --git a/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/policy/HostedVespaClusterPolicyTest.java b/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/policy/HostedVespaClusterPolicyTest.java index 0f56a620f1c..c316f79c3d2 100644 --- a/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/policy/HostedVespaClusterPolicyTest.java +++ b/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/policy/HostedVespaClusterPolicyTest.java @@ -1,15 +1,14 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.orchestrator.policy; -import com.yahoo.vespa.applicationmodel.ApplicationInstanceId; +import com.yahoo.config.provision.ApplicationId; import com.yahoo.vespa.applicationmodel.ClusterId; -import com.yahoo.vespa.applicationmodel.HostName; import com.yahoo.vespa.applicationmodel.ServiceType; -import com.yahoo.vespa.applicationmodel.TenantId; +import com.yahoo.vespa.orchestrator.model.ApplicationApi; import com.yahoo.vespa.orchestrator.model.ClusterApi; import com.yahoo.vespa.orchestrator.model.NodeGroup; import com.yahoo.vespa.orchestrator.model.VespaModelUtil; -import com.yahoo.vespa.orchestrator.status.MutableStatusRegistry; +import org.junit.Before; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -20,9 +19,15 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; public class HostedVespaClusterPolicyTest { + private ApplicationApi applicationApi = mock(ApplicationApi.class); private ClusterApi clusterApi = mock(ClusterApi.class); private HostedVespaClusterPolicy policy = spy(new HostedVespaClusterPolicy()); + @Before + public void setUp() { + when(clusterApi.getApplication()).thenReturn(applicationApi); + } + @Test public void testSlobrokSuspensionLimit() { when(clusterApi.clusterId()).thenReturn(VespaModelUtil.ADMIN_CLUSTER_ID); @@ -48,7 +53,17 @@ public class HostedVespaClusterPolicyTest { } @Test + public void testNodeAdminSuspensionLimit() { + when(applicationApi.applicationId()).thenReturn(VespaModelUtil.ZONE_APPLICATION_ID); + when(clusterApi.clusterId()).thenReturn(VespaModelUtil.NODE_ADMIN_CLUSTER_ID); + when(clusterApi.isStorageCluster()).thenReturn(false); + assertEquals(ConcurrentSuspensionLimitForCluster.TWENTY_PERCENT, + policy.getConcurrentSuspensionLimit(clusterApi)); + } + + @Test public void testDefaultSuspensionLimit() { + when(applicationApi.applicationId()).thenReturn(ApplicationId.fromSerializedForm("a:b:c")); when(clusterApi.clusterId()).thenReturn(new ClusterId("some-cluster-id")); when(clusterApi.isStorageCluster()).thenReturn(false); assertEquals(ConcurrentSuspensionLimitForCluster.TEN_PERCENT, diff --git a/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/policy/HostedVespaPolicyTest.java b/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/policy/HostedVespaPolicyTest.java index ec11ed47ba5..220371c4a17 100644 --- a/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/policy/HostedVespaPolicyTest.java +++ b/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/policy/HostedVespaPolicyTest.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.orchestrator.policy; +import com.yahoo.config.provision.ApplicationId; import com.yahoo.vespa.applicationmodel.HostName; import com.yahoo.vespa.orchestrator.OrchestrationException; import com.yahoo.vespa.orchestrator.controller.ClusterControllerClient; @@ -42,7 +43,7 @@ public class HostedVespaPolicyTest { final HostedVespaClusterPolicy clusterPolicy = mock(HostedVespaClusterPolicy.class); final HostedVespaPolicy policy = new HostedVespaPolicy(clusterPolicy, clientFactory); final ApplicationApi applicationApi = mock(ApplicationApi.class); - when(applicationApi.applicationInfo()).thenReturn("tenant:app"); + when(applicationApi.applicationId()).thenReturn(ApplicationId.fromSerializedForm("tenant:app:default")); ClusterApi clusterApi1 = mock(ClusterApi.class); ClusterApi clusterApi2 = mock(ClusterApi.class); @@ -93,7 +94,7 @@ public class HostedVespaPolicyTest { final HostedVespaClusterPolicy clusterPolicy = mock(HostedVespaClusterPolicy.class); final HostedVespaPolicy policy = new HostedVespaPolicy(clusterPolicy, clientFactory); final ApplicationApi applicationApi = mock(ApplicationApi.class); - when(applicationApi.applicationInfo()).thenReturn("tenant:app"); + when(applicationApi.applicationId()).thenReturn(ApplicationId.fromSerializedForm("tenant:app:default")); ClusterApi clusterApi1 = mock(ClusterApi.class); ClusterApi clusterApi2 = mock(ClusterApi.class); diff --git a/parent/pom.xml b/parent/pom.xml index 46b8c90baef..02282f8218d 100644 --- a/parent/pom.xml +++ b/parent/pom.xml @@ -172,7 +172,7 @@ <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-source-plugin</artifactId> - <version>2.1.2</version> + <version>2.4</version> <configuration> <includePom>true</includePom> </configuration> @@ -385,6 +385,11 @@ <artifactId>commons-exec</artifactId> <version>1.3</version> </dependency> + <dependency> + <groupId>org.apache.velocity</groupId> + <artifactId>velocity</artifactId> + <version>1.7</version> + </dependency> <dependency> <groupId>io.airlift</groupId> <artifactId>airline</artifactId> @@ -23,941 +23,6 @@ </developer> </developers> - <distributionManagement> - <repository> - <id>bintray-vespa-repo</id> - <url>https://api.bintray.com/maven/yahoo/maven/vespa;publish=1</url> - </repository> - </distributionManagement> - - <repositories> - <!-- Required for Athenz libraries --> - <repository> - <snapshots> - <enabled>false</enabled> - </snapshots> - <id>bintray-yahoo-maven</id> - <name>bintray</name> - <url>https://yahoo.bintray.com/maven</url> - </repository> - </repositories> - - <scm> - <connection>scm:git:git@github.com:vespa-engine/vespa.git</connection> - <developerConnection>scm:git:git@github.com:vespa-engine/vespa.git</developerConnection> - <url>git@github.com:vespa-engine/vespa.git</url> - </scm> - - <build> - <finalName>${project.artifactId}</finalName> - <extensions> - <extension> - <groupId>org.apache.maven.wagon</groupId> - <artifactId>wagon-ssh-external</artifactId> - <version>2.7</version> - </extension> - <extension> - <groupId>org.apache.maven.archetype</groupId> - <artifactId>archetype-packaging</artifactId> - <version>2.0</version> - </extension> - </extensions> - <pluginManagement> - <plugins> - <plugin> - <groupId>org.antlr</groupId> - <artifactId>antlr3-maven-plugin</artifactId> - <version>${antlr.version}</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-antrun-plugin</artifactId> - <version>1.7</version> - </plugin> - <plugin> - <groupId>org.apache.felix</groupId> - <artifactId>maven-bundle-plugin</artifactId> - <version>2.4.0</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-assembly-plugin</artifactId> - <version>2.4</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-compiler-plugin</artifactId> - <version>3.6.1</version> - <configuration> - <source>1.8</source> - <target>1.8</target> - <showWarnings>true</showWarnings> - <optimize>true</optimize> - <showDeprecation>false</showDeprecation> - <compilerArgs> - <arg>-Xlint:all</arg> - <arg>-Xlint:-serial</arg> - <arg>-Xlint:-try</arg> - <arg>-Xlint:-processing</arg> - <arg>-Xlint:-varargs</arg> - <arg>-Werror</arg> - </compilerArgs> - </configuration> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-dependency-plugin</artifactId> - <version>2.10</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-deploy-plugin</artifactId> - <version>2.5</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-install-plugin</artifactId> - <version>2.5.2</version> - <configuration> - <updateReleaseInfo>true</updateReleaseInfo> - </configuration> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-jar-plugin</artifactId> - <version>3.0.2</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-javadoc-plugin</artifactId> - <configuration> - <additionalparam>-Xdoclint:${doclint} -Xdoclint:-missing</additionalparam> - </configuration> - <version>2.10.4</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-plugin-plugin</artifactId> - <version>3.5</version> - <configuration> - <!-- see http://jira.codehaus.org/browse/MNG-5346 --> - <skipErrorNoDescriptorsFound>true</skipErrorNoDescriptorsFound> - </configuration> - <executions> - <execution> - <id>mojo-descriptor</id> - <goals> - <goal>descriptor</goal> - </goals> - </execution> - </executions> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-resources-plugin</artifactId> - <version>2.7</version> - <configuration> - <escapeString>\</escapeString> - </configuration> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-site-plugin</artifactId> - <version>3.3</version> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-source-plugin</artifactId> - <version>2.1.2</version> - <configuration> - <includePom>true</includePom> - </configuration> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-surefire-plugin</artifactId> - <version>${surefire.version}</version> - <configuration> - <redirectTestOutputToFile>${test.hide}</redirectTestOutputToFile> - <systemPropertyVariables> - <java.io.tmpdir>${project.build.directory}</java.io.tmpdir> - </systemPropertyVariables> - </configuration> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-surefire-report-plugin</artifactId> - <version>${surefire.version}</version> - <configuration> - <alwaysGenerateSurefireReport>false</alwaysGenerateSurefireReport> - <showSuccess>false</showSuccess> - </configuration> - </plugin> - <plugin> - <groupId>org.codehaus.mojo</groupId> - <artifactId>build-helper-maven-plugin</artifactId> - <version>1.9.1</version> - </plugin> - <plugin> - <groupId>org.codehaus.mojo</groupId> - <artifactId>exec-maven-plugin</artifactId> - <version>1.6.0</version> - </plugin> - <plugin> - <groupId>org.codehaus.mojo</groupId> - <artifactId>javacc-maven-plugin</artifactId> - <version>2.6</version> - </plugin> - <plugin> - <groupId>org.codehaus.mojo</groupId> - <artifactId>properties-maven-plugin</artifactId> - <version>1.0.0</version> - </plugin> - <plugin> - <groupId>net.alchim31.maven</groupId> - <artifactId>scala-maven-plugin</artifactId> - <version>3.2.2</version> - <configuration> - <args> - <arg>-unchecked</arg> - <arg>-deprecation</arg> - <arg>-feature</arg> - <arg>-Xfatal-warnings</arg> - </args> - </configuration> - </plugin> - <plugin> - <groupId>com.yahoo.vespa</groupId> - <artifactId>bundle-plugin</artifactId> - <version>${project.version}</version> - <configuration> - <configGenVersion>${project.version}</configGenVersion> - <useCommonAssemblyIds>true</useCommonAssemblyIds> - </configuration> - </plugin> - </plugins> - </pluginManagement> - </build> - <profiles> - <profile> - <id>attach-sources</id> - <activation> - <property> - <name>!skipSources</name> - </property> - </activation> - <build> - <plugins> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-source-plugin</artifactId> - <executions> - <execution> - <id>attach-sources</id> - <goals> - <goal>jar-no-fork</goal> - </goals> - </execution> - </executions> - </plugin> - </plugins> - </build> - </profile> - <profile> - <id>generate-javadoc</id> - <activation> - <property> - <name>!skipJavadoc</name> - </property> - </activation> - <build> - <plugins> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-javadoc-plugin</artifactId> - <executions> - <execution> - <id>generate-javadoc</id> - <phase>package</phase> - <goals> - <goal>javadoc</goal> - </goals> - </execution> - </executions> - <configuration> - <additionalparam>-Xdoclint:${doclint} -Xdoclint:-missing</additionalparam> - <failOnError>${javadoc.failOnError}</failOnError> - <quiet>true</quiet> - <show>private</show> - </configuration> - </plugin> - </plugins> - </build> - </profile> - <profile> - <id>coverage</id> - <build> - <plugins> - <plugin> - <groupId>org.codehaus.mojo</groupId> - <artifactId>exec-maven-plugin</artifactId> - <configuration> - <includePluginDependencies>true</includePluginDependencies> - </configuration> - </plugin> - <plugin> - <groupId>org.codehaus.mojo</groupId> - <artifactId>build-helper-maven-plugin</artifactId> - <executions> - <execution> - <phase>generate-sources</phase> - <goals> - <goal>add-source</goal> - </goals> - <configuration> - <sources> - <source>src/main/scala</source> - </sources> - </configuration> - </execution> - <execution> - <id>add-test-source</id> - <phase>generate-test-sources</phase> - <goals> - <goal>add-test-source</goal> - </goals> - <configuration> - <sources> - <source>src/test/scala</source> - </sources> - </configuration> - </execution> - </executions> - </plugin> - </plugins> - </build> - </profile> - <profile> - <id>sign-artifacts</id> - <build> - <plugins> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-gpg-plugin</artifactId> - <version>1.6</version> - <executions> - <execution> - <id>sign-artifacts</id> - <phase>verify</phase> - <goals> - <goal>sign</goal> - </goals> - </execution> - </executions> - </plugin> - </plugins> - </build> - </profile> - </profiles> - <dependencyManagement> - <dependencies> - <dependency> - <groupId>org.apache.maven.wagon</groupId> - <artifactId>wagon-ssh-external</artifactId> - <version>2.7</version> - </dependency> - <dependency> - <groupId>com.github.cverges.expect4j</groupId> - <artifactId>expect4j</artifactId> - <version>1.6</version> - </dependency> - <dependency> - <groupId>org.apache.commons</groupId> - <artifactId>commons-compress</artifactId> - <version>1.11</version> - </dependency> - <dependency> - <groupId>org.apache.commons</groupId> - <artifactId>commons-exec</artifactId> - <version>1.3</version> - </dependency> - <dependency> - <groupId>io.airlift</groupId> - <artifactId>airline</artifactId> - <version>0.7</version> - </dependency> - <dependency> - <groupId>aopalliance</groupId> - <artifactId>aopalliance</artifactId> - <version>1.0</version> - </dependency> - <dependency> - <groupId>org.ow2.asm</groupId> - <artifactId>asm</artifactId> - <version>5.2</version> - </dependency> - <dependency> - <groupId>com.google.code.findbugs</groupId> - <artifactId>annotations</artifactId> - <version>1.3.9</version> - </dependency> - <dependency> - <groupId>com.google.code.findbugs</groupId> - <artifactId>jsr305</artifactId> - <version>1.3.9</version> - </dependency> - <dependency> - <groupId>com.google.guava</groupId> - <artifactId>guava</artifactId> - <version>18.0</version> - </dependency> - <dependency> - <groupId>com.google.guava</groupId> - <artifactId>guava-testlib</artifactId> - <version>18.0</version> - </dependency> - <dependency> - <groupId>com.google.inject</groupId> - <artifactId>guice</artifactId> - <version>3.0</version> - </dependency> - <dependency> - <groupId>com.google.inject</groupId> - <artifactId>guice</artifactId> - <version>3.0</version> - <classifier>no_aop</classifier> - </dependency> - <dependency> - <groupId>com.google.inject.extensions</groupId> - <artifactId>guice-assistedinject</artifactId> - <version>3.0</version> - </dependency> - <dependency> - <groupId>com.google.inject.extensions</groupId> - <artifactId>guice-multibindings</artifactId> - <version>3.0</version> - </dependency> - <dependency> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - <version>3.4.0</version> - </dependency> - <dependency> - <groupId>com.googlecode.jmockit</groupId> - <artifactId>jmockit</artifactId> - <version>1.2</version> - </dependency> - <dependency> - <groupId>com.goldmansachs</groupId> - <artifactId>gs-collections</artifactId> - <version>6.1.0</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.core</groupId> - <artifactId>jackson-core</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.core</groupId> - <artifactId>jackson-databind</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.core</groupId> - <artifactId>jackson-annotations</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.jaxrs</groupId> - <artifactId>jackson-jaxrs-json-provider</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.module</groupId> - <artifactId>jackson-module-jaxb-annotations</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.jaxrs</groupId> - <artifactId>jackson-jaxrs-base</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.jaxrs</groupId> - <artifactId>jackson-jaxrs-xml-provider</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.dataformat</groupId> - <artifactId>jackson-dataformat-xml</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.datatype</groupId> - <artifactId>jackson-datatype-jdk8</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.fasterxml.jackson.datatype</groupId> - <artifactId>jackson-datatype-jsr310</artifactId> - <version>${jackson2.version}</version> - </dependency> - <dependency> - <groupId>com.infradna.tool</groupId> - <artifactId>bridge-method-annotation</artifactId> - <version>1.4</version> - </dependency> - <dependency> - <groupId>commons-cli</groupId> - <artifactId>commons-cli</artifactId> - <version>1.3.1</version> - </dependency> - <dependency> - <groupId>commons-codec</groupId> - <artifactId>commons-codec</artifactId> - <version>1.4</version> - </dependency> - <dependency> - <groupId>commons-collections</groupId> - <artifactId>commons-collections</artifactId> - <version>3.2.1</version> - </dependency> - <dependency> - <groupId>commons-configuration</groupId> - <artifactId>commons-configuration</artifactId> - <version>1.6</version> - </dependency> - <dependency> - <groupId>commons-daemon</groupId> - <artifactId>commons-daemon</artifactId> - <version>1.0.3</version> - </dependency> - <dependency> - <groupId>commons-io</groupId> - <artifactId>commons-io</artifactId> - <version>2.4</version> - </dependency> - <dependency> - <groupId>commons-lang</groupId> - <artifactId>commons-lang</artifactId> - <version>${commons-lang.version}</version> - </dependency> - <dependency> - <!-- This version is exported by jdisc via jcl-over-slf4j. --> - <groupId>commons-logging</groupId> - <artifactId>commons-logging</artifactId> - <version>1.1.1</version> - </dependency> - <dependency> - <groupId>commons-net</groupId> - <artifactId>commons-net</artifactId> - <version>2.0</version> - </dependency> - <dependency> - <groupId>commons-pool</groupId> - <artifactId>commons-pool</artifactId> - <version>1.5.6</version> - </dependency> - <!-- Explicitly included to get Zookeeper version 3.4.10, - can be excluded if you want the Zookeeper version - used by curator by default - --> - <dependency> - <groupId>org.apache.zookeeper</groupId> - <artifactId>zookeeper</artifactId> - <version>3.4.10</version> - </dependency> - <dependency> - <groupId>org.apache.curator</groupId> - <artifactId>curator-recipes</artifactId> - <version>${curator.version}</version> - </dependency> - <dependency> - <groupId>org.apache.curator</groupId> - <artifactId>curator-test</artifactId> - <version>${curator.version}</version> - </dependency> - <dependency> - <groupId>javax.servlet</groupId> - <artifactId>javax.servlet-api</artifactId> - <version>3.1.0</version> - </dependency> - <dependency> - <groupId>junit</groupId> - <artifactId>junit</artifactId> - <version>4.12</version> - </dependency> - <dependency> - <groupId>org.antlr</groupId> - <artifactId>antlr-runtime</artifactId> - <version>${antlr.version}</version> - </dependency> - <dependency> - <groupId>org.antlr</groupId> - <artifactId>antlr4-runtime</artifactId> - <version>${antlr4.version}</version> - </dependency> - <dependency> - <groupId>org.apache.aries.spifly</groupId> - <artifactId>org.apache.aries.spifly.dynamic.bundle</artifactId> - <version>${aries.spifly.version}</version> - </dependency> - <dependency> - <groupId>org.apache.commons</groupId> - <artifactId>commons-lang3</artifactId> - <version>3.1</version> - </dependency> - <dependency> - <groupId>org.apache.felix</groupId> - <artifactId>org.apache.felix.framework</artifactId> - <version>4.2.1</version> - </dependency> - <dependency> - <groupId>org.apache.felix</groupId> - <artifactId>org.apache.felix.log</artifactId> - <version>1.0.1</version> - </dependency> - <dependency> - <groupId>org.apache.felix</groupId> - <artifactId>org.apache.felix.main</artifactId> - <version>4.2.1</version> - </dependency> - <dependency> - <groupId>org.apache.httpcomponents</groupId> - <artifactId>fluent-hc</artifactId> - <version>4.3.6</version> - </dependency> - <dependency> - <groupId>org.apache.httpcomponents</groupId> - <artifactId>httpclient</artifactId> - <version>4.3.6</version> - </dependency> - <dependency> - <groupId>org.apache.httpcomponents</groupId> - <artifactId>httpcore</artifactId> - <version>4.3.3</version> - </dependency> - <dependency> - <groupId>org.apache.httpcomponents</groupId> - <artifactId>httpmime</artifactId> - <version>4.3.6</version> - </dependency> - <dependency> - <groupId>org.apache.maven</groupId> - <artifactId>maven-artifact</artifactId> - <version>3.5.0</version> - </dependency> - <dependency> - <groupId>org.apache.maven</groupId> - <artifactId>maven-core</artifactId> - <version>3.5.0</version> - </dependency> - <dependency> - <groupId>org.apache.maven</groupId> - <artifactId>maven-model</artifactId> - <version>3.5.0</version> - </dependency> - <dependency> - <groupId>org.apache.maven.plugin-tools</groupId> - <artifactId>maven-plugin-annotations</artifactId> - <version>3.5</version> - </dependency> - <dependency> - <groupId>org.apache.maven</groupId> - <artifactId>maven-plugin-api</artifactId> - <version>3.5.0</version> - </dependency> - <dependency> - <groupId>org.apache.maven</groupId> - <artifactId>maven-project</artifactId> - <version>2.2.1</version> - </dependency> - <dependency> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-jar-plugin</artifactId> - <version>3.0.2</version> - </dependency> - <dependency> - <groupId>org.apache.maven.surefire</groupId> - <artifactId>surefire-junit4</artifactId> - <version>${surefire.version}</version> - </dependency> - <dependency> - <groupId>org.apache.maven.surefire</groupId> - <artifactId>surefire-providers</artifactId> - <version>${surefire.version}</version> - <type>pom</type> - </dependency> - <dependency> - <groupId>org.codehaus.jettison</groupId> - <artifactId>jettison</artifactId> - <version>1.3.1</version> - </dependency> - <dependency> - <groupId>org.cthul</groupId> - <artifactId>cthul-matchers</artifactId> - <version>1.0</version> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.eclipse.jetty</groupId> - <artifactId>jetty-continuation</artifactId> - <version>${jetty.version}</version> - </dependency> - <dependency> - <groupId>org.eclipse.jetty</groupId> - <artifactId>jetty-server</artifactId> - <version>${jetty.version}</version> - </dependency> - <dependency> - <groupId>org.eclipse.jetty</groupId> - <artifactId>jetty-servlet</artifactId> - <version>${jetty.version}</version> - </dependency> - <dependency> - <groupId>org.eclipse.jetty</groupId> - <artifactId>jetty-servlets</artifactId> - <version>${jetty.version}</version> - </dependency> - <dependency> - <groupId>org.eclipse.jetty</groupId> - <artifactId>jetty-util</artifactId> - <version>${jetty.version}</version> - </dependency> - <dependency> - <groupId>org.eclipse.jetty</groupId> - <artifactId>jetty-http</artifactId> - <version>${jetty.version}</version> - </dependency> - <dependency> - <groupId>org.eclipse.jetty</groupId> - <artifactId>jetty-jmx</artifactId> - <version>${jetty.version}</version> - </dependency> - <dependency> - <groupId>org.hamcrest</groupId> - <artifactId>hamcrest-all</artifactId> - <version>1.3</version> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.hamcrest</groupId> - <artifactId>hamcrest-core</artifactId> - <version>1.3</version> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.hamcrest</groupId> - <artifactId>hamcrest-library</artifactId> - <version>1.3</version> - <scope>test</scope> - </dependency> - <dependency> - <groupId>uk.co.datumedge</groupId> - <artifactId>hamcrest-json</artifactId> - <version>0.2</version> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.hdrhistogram</groupId> - <artifactId>HdrHistogram</artifactId> - <version>2.1.8</version> - </dependency> - <dependency> - <groupId>org.json</groupId> - <artifactId>json</artifactId> - <version>20090211</version> - </dependency> - <dependency> - <groupId>org.mockito</groupId> - <artifactId>mockito-all</artifactId> - <version>1.9.5</version> - </dependency> - <dependency> - <groupId>org.mockito</groupId> - <artifactId>mockito-core</artifactId> - <version>1.9.5</version> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.osgi</groupId> - <artifactId>org.osgi.compendium</artifactId> - <version>4.3.0</version> - </dependency> - <dependency> - <groupId>org.osgi</groupId> - <artifactId>org.osgi.core</artifactId> - <version>4.3.0</version> - </dependency> - <dependency> - <groupId>org.scala-lang</groupId> - <artifactId>scala-library</artifactId> - <version>${scala.version}</version> - </dependency> - <dependency> - <groupId>org.scala-lang.modules</groupId> - <artifactId>scala-parser-combinators_${scala.major-version}</artifactId> - <version>1.0.1</version> - </dependency> - <dependency> - <groupId>org.scala-lang.modules</groupId> - <artifactId>scala-xml_${scala.major-version}</artifactId> - <version>1.0.2</version> - </dependency> - <dependency> - <groupId>org.scalatest</groupId> - <artifactId>scalatest_${scala.major-version}</artifactId> - <version>2.2.2</version> - </dependency> - <dependency> - <groupId>org.slf4j</groupId> - <artifactId>jcl-over-slf4j</artifactId> - <version>1.7.5</version> - </dependency> - <dependency> - <groupId>org.slf4j</groupId> - <artifactId>log4j-over-slf4j</artifactId> - <version>1.7.5</version> - </dependency> - <dependency> - <groupId>org.slf4j</groupId> - <artifactId>slf4j-api</artifactId> - <version>1.7.5</version> - </dependency> - <dependency> - <groupId>org.slf4j</groupId> - <artifactId>slf4j-jdk14</artifactId> - <version>1.7.5</version> - </dependency> - <dependency> - <groupId>org.springframework</groupId> - <artifactId>spring-test</artifactId> - <version>4.0.6.RELEASE</version> - </dependency> - <dependency> - <groupId>org.testng</groupId> - <artifactId>testng</artifactId> - <version>6.10</version> - </dependency> - <dependency> - <groupId>org.twdata.maven</groupId> - <artifactId>mojo-executor</artifactId> - <version>2.3.0</version> - </dependency> - <dependency> - <groupId>net.jcip</groupId> - <artifactId>jcip-annotations</artifactId> - <version>1.0</version> - </dependency> - <dependency> - <groupId>net.jpountz.lz4</groupId> - <artifactId>lz4</artifactId> - <version>1.3.0</version> - </dependency> - <dependency> - <groupId>net.spy</groupId> - <artifactId>spymemcached</artifactId> - <version>2.10.1</version> - </dependency> - <dependency> - <groupId>xerces</groupId> - <artifactId>xercesImpl</artifactId> - <version>2.11.0</version> - </dependency> - <dependency> - <groupId>org.bouncycastle</groupId> - <artifactId>bcpkix-jdk15on</artifactId> - <version>${bouncycastle.version}</version> - </dependency> - <dependency> - <groupId>org.bouncycastle</groupId> - <artifactId>bcprov-jdk15on</artifactId> - <version>${bouncycastle.version}</version> - </dependency> - <!-- jersey 2 support --> - <dependency> - <groupId>javax.ws.rs</groupId> - <artifactId>javax.ws.rs-api</artifactId> - <version>${javax.ws.rs-api.version}</version> - </dependency> - <dependency> - <groupId>org.glassfish.jersey.containers</groupId> - <artifactId>jersey-container-servlet-core</artifactId> - <version>${jersey2.version}</version> - </dependency> - <dependency> - <groupId>org.glassfish.jersey.containers</groupId> - <artifactId>jersey-container-servlet</artifactId> - <version>${jersey2.version}</version> - </dependency> - <dependency> - <groupId>org.glassfish.jersey.media</groupId> - <artifactId>jersey-media-json-jackson</artifactId> - <version>${jersey2.version}</version> - </dependency> - <dependency> - <groupId>org.glassfish.jersey.media</groupId> - <artifactId>jersey-media-multipart</artifactId> - <version>${jersey2.version}</version> - </dependency> - <dependency> - <groupId>org.glassfish.jersey.ext</groupId> - <artifactId>jersey-proxy-client</artifactId> - <version>${jersey2.version}</version> - </dependency> - <dependency> - <groupId>org.glassfish.jersey.core</groupId> - <artifactId>jersey-client</artifactId> - <version>${jersey2.version}</version> - </dependency> - <dependency> - <groupId>com.ibm.icu</groupId> - <artifactId>icu4j</artifactId> - <version>57.1</version> - </dependency> - <dependency> - <groupId>com.yahoo.athenz</groupId> - <artifactId>athenz-zms-java-client</artifactId> - <version>${athenz.version}</version> - </dependency> - <dependency> - <groupId>com.yahoo.athenz</groupId> - <artifactId>athenz-zts-java-client</artifactId> - <version>${athenz.version}</version> - </dependency> - </dependencies> - </dependencyManagement> - - <properties> - <javax.ws.rs-api.version>2.0.1</javax.ws.rs-api.version> <!-- must be kept in sync with version used by current jersey2.version --> - <antlr.version>3.5.2</antlr.version> - <antlr4.version>4.5</antlr4.version> - <aries.spifly.version>1.0.8</aries.spifly.version> - <aries.util.version>1.0.0</aries.util.version> - <asm-debug-all.version>5.0.3</asm-debug-all.version> - <!-- Athenz dependencies. Make sure these dependencies matches those in Vespa's internal repositories --> - <athenz.version>1.7.28</athenz.version> - <bouncycastle.version>1.58</bouncycastle.version> - <commons-lang.version>2.6</commons-lang.version> - <!-- WARNING: If you change curator version, you also need to update - zkfacade/src/main/java/org/apache/curator/**/package-info.java - using something like - find zkfacade/src/main/java/org/apache/curator -name package-info.java | \ - xargs perl -pi -e 's/major = [0-9]+, minor = [0-9]+, micro = [0-9]+/major = 2, minor = 9, micro = 1/g' - --> - <curator.version>2.9.1</curator.version> - <jackson2.version>2.8.3</jackson2.version> - <jersey2.version>2.23.2</jersey2.version> - <jetty.version>9.4.6.v20170531</jetty.version> - <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> - <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding> - <test.hide>true</test.hide> - <doclint>all</doclint> - <scala.major-version>2.11</scala.major-version> - <scala.version>${scala.major-version}.4</scala.version> - <surefire.version>2.19.1</surefire.version> <!-- NOTE bjorncs 15.06.2017: Version 2.20 has OoM issues --> - </properties> - <modules> <module>application</module> <module>application-deploy-plugin</module> diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java new file mode 100644 index 00000000000..5f0c016881a --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java @@ -0,0 +1,132 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; +import org.tensorflow.framework.TensorProto; +import org.tensorflow.framework.TensorShapeProto; + +/** + * @author lesters + */ +public class AttrValueConverter { + + public static Tensor toVespaTensor(NodeDef tfNode, String attr) { + if (!tfNode.getAttrMap().containsKey(attr)) { + throw new IllegalArgumentException(tfNode.getName() + " has no attribute called " + attr); + } + AttrValue attrValue = tfNode.getAttrMap().get(attr); + switch (attrValue.getValueCase()) { + case TENSOR: + return buildFromTensor(attrValue); + case B: + return buildFromSingleValue(attrValue.getB() ? 1.0 : 0.0); + case F: + return buildFromSingleValue(attrValue.getF()); + case I: + return buildFromSingleValue(attrValue.getI()); + } + + throw new IllegalArgumentException(tfNode.getName() + + ": unsupported attribute type: '" + attrValue.getValueCase().toString() + "'"); + } + + private static Tensor buildFromSingleValue(double value) { + TensorType type = new TensorType.Builder().build(); + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); + builder.cellByDirectIndex(0, value); + return builder.build(); + } + + private static Tensor buildFromTensor(AttrValue attrValue) { + TensorProto tensorProto = attrValue.getTensor(); + TensorType type = toVespaTensorType(tensorProto.getTensorShape()); + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); + Values values = valuesOf(tensorProto); + for (int i = 0; i < values.size(); ++i) { + builder.cellByDirectIndex(i, values.get(i)); + } + Tensor tensor = builder.build(); + return tensor; + } + + private static Values valuesOf(TensorProto tensorProto) { + switch (tensorProto.getDtype()) { + case DT_BOOL: + return new BoolValues(tensorProto); + case DT_HALF: + return new HalfValues(tensorProto); + case DT_INT16: + case DT_INT32: + return new IntValues(tensorProto); + case DT_INT64: + return new Int64Values(tensorProto); + case DT_FLOAT: + return new FloatValues(tensorProto); + case DT_DOUBLE: + return new DoubleValues(tensorProto); + } + + throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); + } + + public static TensorType toVespaTensorType(TensorShapeProto shapeProto) { + TensorType.Builder b = new TensorType.Builder(); + for (TensorShapeProto.Dim dimension : shapeProto.getDimList()) { + int dimensionSize = (int)dimension.getSize(); + if (dimensionSize >= 0) + b.indexed("d" + b.rank(), dimensionSize); + else + b.indexed("d" + b.rank()); // unbound size + } + return b.build(); + } + + private static abstract class Values { + protected final TensorProto tensorProto; + protected Values(TensorProto tensorProto) { this.tensorProto = tensorProto; } + abstract double get(int i); + abstract int size(); + } + + private static class BoolValues extends Values { + BoolValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; } + @Override int size() { return tensorProto.getBoolValCount(); } + } + + private static class HalfValues extends Values { + HalfValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getHalfVal(i); } + @Override int size() { return tensorProto.getHalfValCount(); } + } + + private static class IntValues extends Values { + IntValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getIntVal(i); } + @Override int size() { return tensorProto.getIntValCount(); } + } + + private static class Int64Values extends Values { + Int64Values(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getInt64Val(i); } + @Override int size() { return tensorProto.getInt64ValCount(); } + } + + private static class FloatValues extends Values { + FloatValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getFloatVal(i); } + @Override int size() { return tensorProto.getFloatValCount(); } + } + + private static class DoubleValues extends Values { + DoubleValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getDoubleVal(i); } + @Override int size() { return tensorProto.getDoubleValCount(); } + } + + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index 85452d16a77..816ef38e128 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -24,157 +24,318 @@ import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.Softmax; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.SavedModelBundle; import org.tensorflow.Session; import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.Optional; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; +import java.util.function.Function; import java.util.stream.Collectors; -import java.util.stream.StreamSupport; +import java.util.stream.Stream; /** * Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions. * * @author bratseth + * @author lesters */ class OperationMapper { + // A note on conversion from implicitly numbered to explicitly named dimensions: + // + // Vespa tensor dimensions are explicitly named and thus have an explicit notion of being + // 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation + // comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation + // around dimension renaming operations which mirrors those built into the TF operation definitions. + // + // To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost' + // dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation + // and the result is then renamed again (if necessary) to recover this convention across a full nested + // computation. + // + // This requires us to track tensor types throughout the conversion. + + + // Supported TensorFlow operations + enum Operation { + + // TODO: move the implementations to specific files as we support more operations + + /* + * array ops + */ + CONST (OperationMapper::constant), + EXPANDDIMS (OperationMapper::expandDims), + IDENTITY (OperationMapper::identity), + PLACEHOLDER (OperationMapper::placeholder), + PLACEHOLDERWITHDEFAULT (OperationMapper::placeholderWithDefault), + RESHAPE (OperationMapper::reshape), + SQUEEZE (OperationMapper::squeeze), + + /* + * control flow + */ + MERGE (OperationMapper::merge), + SWITCH (OperationMapper::switchOp), + + /* + * math ops + */ + ADD (OperationMapper::add), + ADD_N (OperationMapper::add), + ACOS (OperationMapper::acos), + DIV (OperationMapper::div), + REALDIV (OperationMapper::div), + FLOOR (OperationMapper::floor), + MATMUL (OperationMapper::matmul), + MAXIMUM (OperationMapper::maximum), + MEAN (OperationMapper::mean), + REDUCEMEAN (OperationMapper::mean), + MUL (OperationMapper::mul), + MULTIPLY (OperationMapper::mul), + RSQRT (OperationMapper::rsqrt), + SELECT (OperationMapper::select), + WHERE3 (OperationMapper::select), + SIGMOID (OperationMapper::sigmoid), + SQUAREDDIFFERENCE (OperationMapper::squaredDifference), + SUB (OperationMapper::sub), + SUBTRACT (OperationMapper::sub), + + /* + * nn ops + */ + BIASADD (OperationMapper::add), + ELU (OperationMapper::elu), + RELU (OperationMapper::relu), + SELU (OperationMapper::selu), + SOFTMAX (OperationMapper::softMax), + + /* + * state ops + */ + VARIABLE (OperationMapper::variable), + VARIABLEV2 (OperationMapper::variable), + + /* + * evaluation no-ops + */ + STOPGRADIENT (OperationMapper::identity), + NOOP (OperationMapper::noOp); + + + private final Function<TensorFlowImporter.Parameters, Optional<TypedTensorFunction>> func; + + Operation(Function<TensorFlowImporter.Parameters, Optional<TypedTensorFunction>> func) { + this.func = func; + } + + Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params) { + return func.apply(params); + } + + } + + static Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params) { + Optional<Operation> operation = Stream.of(Operation.values()) + .filter(op -> op.name().equalsIgnoreCase(params.node().getOp())) + .findFirst(); + if (operation.isPresent()) { + return operation.get().map(params); + } + params.signature().importWarning("TensorFlow operation '" + params.node().getOp() + + "' in node '" + params.node().getName() + "' is not supported."); + return Optional.empty(); + } + + /* - A note on conversion from implicitly numbered to explicitly named dimensions: - Vespa tensor dimensions are explicitly named and thus have an explicit notion of being - 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation - comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation - around dimension renaming operations which mirrors those built into the TF operation definitions. - - To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost' - dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation - and the result is then renamed again (if necessary) to recover this convention across a full nested - computation. - - This requires us to track tensor types throughout the conversion. + * Operations */ - private TensorConverter tensorConverter = new TensorConverter(); + private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) { + Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value"); + return createConstant(params, value); + } - TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) { - ensureArguments(2, arguments, "join"); - TypedTensorFunction a = arguments.get(0); - TypedTensorFunction b = arguments.get(1); + private static Optional<TypedTensorFunction> expandDims(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 2)) { + return Optional.empty(); + } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); - if (a.type().rank() == 0 && b.type().rank() > 0) { - return new TypedTensorFunction(b.type(), new Join(a.function(), b.function(), doubleFunction)); + Tensor axis = getConstantTensor(params, params.node().getInput(1)); + if (axis.type().rank() != 0) { + throw new IllegalArgumentException("Axis argument to ExpandDims must be a scalar"); } - if (b.type().rank() == 0 && a.type().rank() > 0) { - return new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction)); + + TensorFunction inputFunction = inputs.get(0).get().function(); + TensorType inputType = inputs.get(0).get().type(); + + int dimensionToInsert = (int)axis.asDouble(); + if (dimensionToInsert < 0) { + dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; } - if (a.type().rank() == b.type().rank()) { - return new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction)); + + TensorType.Builder outputTypeBuilder = new TensorType.Builder(); + int dimensionIndex = 0; + for (int i = 0; i < inputType.dimensions().size() + 1; ++i) { + String name = String.format("temp_%d", i); + Long size; + if (i == dimensionToInsert) { + size = 1L; + } else { + size = dimensionSize(inputType.dimensions().get(dimensionIndex)); + dimensionIndex++; + } + outputTypeBuilder.indexed(name, size); } - // Well now we have entered the wonderful world of "broadcasting" - // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html - // I'm not able to extract from that any unambiguous specification of which dimensions - // should be "stretched" when the tensor do not have the same number of dimensions. - // From trying this with TensorFlow it appears that the second tensor is matched to the - // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true. - // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first). + return reshape(inputFunction, inputType, outputTypeBuilder.build()); + } - if (a.type().rank() > b.type().rank()) { - TensorFunction renameFunction = renameForBroadcast(a, b); - return new TypedTensorFunction(a.type(), new Join(a.function(), renameFunction, doubleFunction)); + private static Optional<TypedTensorFunction> identity(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 1)) { + return Optional.empty(); } - TensorFunction renameFunction = renameForBroadcast(b, a); - return new TypedTensorFunction(b.type(), new Join(renameFunction, b.function(), doubleFunction)); + return params.inputs().get(0); } - private TensorFunction renameForBroadcast(TypedTensorFunction a, TypedTensorFunction b) { - List<String> renameFrom = new ArrayList<>(); - List<String> renameTo = new ArrayList<>(); - int sizeDifference = a.type().rank() - b.type().rank(); - for (int i = 0; i < b.type().rank(); i++) { - renameFrom.add(b.type().dimensions().get(i).name()); - renameTo.add("d" + (sizeDifference + i)); + private static Optional<TypedTensorFunction> placeholder(TensorFlowImporter.Parameters params) { + String name = params.node().getName(); + TensorType type = params.result().arguments().get(name); + if (type == null) { + throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name + + "', but there is no such placeholder"); } - return new Rename(b.function(), renameFrom, renameTo); + // Included literally in the expression and so must be produced by a separate macro in the rank profile + TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(name)); + return Optional.of(output); + } + + private static Optional<TypedTensorFunction> placeholderWithDefault(TensorFlowImporter.Parameters params) { + String name = params.node().getInput(0); + Tensor defaultValue = getConstantTensor(params, name); + params.result().constant(name, defaultValue); + params.result().macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")"))); + // The default value will be provided by the macro. Users can override macro to change value. + TypedTensorFunction output = new TypedTensorFunction(defaultValue.type(), new VariableTensor(name)); + return Optional.of(output); } - TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) { - ensureArguments(1, arguments, "apply"); - TypedTensorFunction a = arguments.get(0); + private static Optional<TypedTensorFunction> reshape(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 2)) { + return Optional.empty(); + } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + Tensor shape = getConstantTensor(params, params.node().getInput(1)); - TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type()); - com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction); - return new TypedTensorFunction(resultType, function); + TensorFunction inputFunction = inputs.get(0).get().function(); + TensorType inputType = inputs.get(0).get().type(); + + TensorType.Builder outputTypeBuilder = new TensorType.Builder(); + int dimensionIndex = 0; + for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { + Tensor.Cell cell = cellIterator.next(); + int size = cell.getValue().intValue(); + if (size < 0) { + size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / tensorSize(inputType).intValue(); + } + outputTypeBuilder.indexed(String.format("temp_%d", dimensionIndex), size); + dimensionIndex++; + } + return reshape(inputFunction, inputType, outputTypeBuilder.build()); } - TypedTensorFunction placeholder(NodeDef tfNode, TensorFlowModel result) { - String name = tfNode.getName(); - TensorType type = result.arguments().get(name); - if (type == null) - throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name + - "', but there is no such placeholder"); - // Included literally in the expression and so must be produced by a separate macro in the rank profile - return new TypedTensorFunction(type, new VariableTensor(name)); + private static Optional<TypedTensorFunction> squeeze(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 1)) { + return Optional.empty(); + } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + + TensorFunction inputFunction = inputs.get(0).get().function(); + TensorType inputType = inputs.get(0).get().type(); + List<String> squeezeDimensions; + + AttrValue squeezeDimsAttr = params.node().getAttrMap().get("squeeze_dims"); + if (squeezeDimsAttr == null) { + squeezeDimensions = inputType.dimensions().stream(). + filter(dim -> dimensionSize(dim) == 1). + map(TensorType.Dimension::name). + collect(Collectors.toList()); + } else { + squeezeDimensions = squeezeDimsAttr.getList().getIList().stream(). + map(i -> i < 0 ? inputType.dimensions().size() - i : i). + map(i -> inputType.dimensions().get(i.intValue())). + filter(dim -> dimensionSize(dim) == 1). + map(TensorType.Dimension::name). + collect(Collectors.toList()); + } + + if (squeezeDimensions.isEmpty()) { + return inputs.get(0); + } + + TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); + TensorType outputType = Reduce.outputType(inputType, squeezeDimensions); + TypedTensorFunction output = checkNamingConvention(outputType, outputFunction); + return Optional.of(output); } - TypedTensorFunction placeholderWithDefault(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) { - String name = tfNode.getInput(0); - Tensor defaultValue = getConstantTensor(model, name); - result.constant(name, defaultValue); - result.macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")"))); - // The default value will be provided by the macro. Users can override macro to change value. - return new TypedTensorFunction(defaultValue.type(), new VariableTensor(name)); + private static Optional<TypedTensorFunction> merge(TensorFlowImporter.Parameters params) { + return params.inputs().stream() + .filter(Optional::isPresent) + .findFirst() + .orElse(Optional.empty()); } - TypedTensorFunction constant(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) { - String name = tfNode.getName(); - if (tfNode.getInputList().size() != 0) { - throw new IllegalArgumentException("A constant node must have zero inputs but '" + name + "' has " + - tfNode.getInputList().size()); + private static Optional<TypedTensorFunction> switchOp(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 2)) { + return Optional.empty(); + } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + Tensor predicate = getConstantTensor(params, params.node().getInput(1)); + if (predicate.type().rank() != 0) { + throw new IllegalArgumentException("'switch': predicate must be a scalar"); + } + double pred = predicate.asDouble(); + int output = params.port().length() > 0 ? Integer.parseInt(params.port()) : 0; + if (output < 0 || output > 1) { + throw new IllegalArgumentException("'switch': predicate is not boolean"); + } + if (pred == output) { + return inputs.get(0); } - return importConstantTensor(tfNode, model, result, name); + return Optional.empty(); } - TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) { - if ( ! tfNode.getName().endsWith("/read")) - throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " + - "nodes are only supported when reading variables"); - if (tfNode.getInputList().size() != 1) - throw new IllegalArgumentException("A Variable/read node must have one input but '" + - tfNode.getName() + "' has " + tfNode.getInputList().size()); + private static Optional<TypedTensorFunction> add(TensorFlowImporter.Parameters params) { + return join(params, ScalarFunctions.add()); + } - String name = tfNode.getInput(0); - return importConstantTensor(tfNode, model, result, name); + private static Optional<TypedTensorFunction> acos(TensorFlowImporter.Parameters params) { + return map(params, ScalarFunctions.acos()); } - private TypedTensorFunction importConstantTensor(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result, String name) { - AttrValue shapes = tfNode.getAttrMap().get("_output_shapes"); - if (shapes == null) - throw new IllegalArgumentException("'" + name + "' is missing a tensor shape"); - Tensor constant = getConstantTensor(model, name); - result.constant(name, constant); - return new TypedTensorFunction(constant.type(), - new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(\"" + name + "\")"))); + private static Optional<TypedTensorFunction> div(TensorFlowImporter.Parameters params) { + return join(params, ScalarFunctions.divide()); } - private Tensor getConstantTensor(SavedModelBundle model, String name) { - Session.Runner fetched = model.session().runner().fetch(name); - List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); - if (importedTensors.size() != 1) - throw new IllegalStateException("Expected 1 tensor from fetching " + name + ", but got " + - importedTensors.size()); - return tensorConverter.toVespaTensor(importedTensors.get(0)); + private static Optional<TypedTensorFunction> floor(TensorFlowImporter.Parameters params) { + return map(params, ScalarFunctions.floor()); } - TypedTensorFunction matmul(List<TypedTensorFunction> arguments) { - ensureArguments(2, arguments, "matmul"); - TypedTensorFunction a = arguments.get(0); - TypedTensorFunction b = arguments.get(1); + private static Optional<TypedTensorFunction> matmul(TensorFlowImporter.Parameters params) { + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + if (!checkInputs(params, 2)) { + return Optional.empty(); + } + + TypedTensorFunction a = inputs.get(0).get(); + TypedTensorFunction b = inputs.get(1).get(); if (a.type().rank() < 2 || b.type().rank() < 2) throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); if (a.type().rank() != b.type().rank()) @@ -190,17 +351,24 @@ class OperationMapper { Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"), ImmutableList.of("d1", afterLastDim)); Matmul matmul = new Matmul(a.function(), renamedB, "d1"); - return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"), + TypedTensorFunction output = new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"), new Rename(matmul, afterLastDim, "d1")); + return Optional.of(output); } - TypedTensorFunction mean(NodeDef tfNode, SavedModelBundle model, List<TypedTensorFunction> arguments) { - ensureArguments(2, arguments, "mean"); - Tensor reductionIndices = getConstantTensor(model, tfNode.getInput(1)); + private static Optional<TypedTensorFunction> maximum(TensorFlowImporter.Parameters params) { + return join(params, ScalarFunctions.max()); + } - TensorFunction inputFunction = arguments.get(0).function(); - TensorType inputType = arguments.get(0).type(); + private static Optional<TypedTensorFunction> mean(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 2)) { + return Optional.empty(); + } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + TensorFunction inputFunction = inputs.get(0).get().function(); + TensorType inputType = inputs.get(0).get().type(); + Tensor reductionIndices = getConstantTensor(params, params.node().getInput(1)); List<String> reduceDimensions = new ArrayList<>(); for (Iterator<Tensor.Cell> cellIterator = reductionIndices.cellIterator(); cellIterator.hasNext();) { Tensor.Cell cell = cellIterator.next(); @@ -214,122 +382,195 @@ class OperationMapper { TensorType outputType = Reduce.outputType(inputType, reduceDimensions); TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions); - if (shouldKeepDimensions(tfNode)) { + if (shouldKeepDimensions(params)) { return reshape(outputFunction, outputType, keepDimensionType(inputType, reduceDimensions)); } - TypedTensorFunction output = checkNamingConvention(outputType, outputFunction); - return output; + return Optional.of(output); } - private boolean shouldKeepDimensions(NodeDef tfNode) { - AttrValue keepDimsAttr = tfNode.getAttrMap().get("keep_dims"); - return keepDimsAttr != null && keepDimsAttr.getB(); + private static Optional<TypedTensorFunction> mul(TensorFlowImporter.Parameters params) { + return join(params, ScalarFunctions.multiply()); } - private TensorType keepDimensionType(TensorType inputType, List<String> reduceDimensions) { - TensorType.Builder builder = new TensorType.Builder(); - for (TensorType.Dimension dimension: inputType.dimensions()) { - String name = dimension.name(); - Long size = dimensionSize(dimension); - if (reduceDimensions.contains(name)) { - size = 1L; - } - builder.indexed(name, size); - } - return builder.build(); + private static Optional<TypedTensorFunction> rsqrt(TensorFlowImporter.Parameters params) { + return map(params, ScalarFunctions.rsqrt()); } - private TypedTensorFunction checkNamingConvention(TensorType type, TensorFunction function) { - for (int i = 0; i < type.dimensions().size(); ++i) { - String correct = String.format("d%d", i); - String current = type.dimensions().get(i).name(); - if (!current.equals(correct)) { - return fixNamingConvention(type, function); - } + private static Optional<TypedTensorFunction> select(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 3)) { + return Optional.empty(); } - return new TypedTensorFunction(type, function); - } + Tensor condition = getConstantTensor(params, params.node().getInput(0)); - private TypedTensorFunction fixNamingConvention(TensorType type, TensorFunction function) { - TensorType.Builder correctType = new TensorType.Builder(); - List<String> from = new ArrayList<>(); - List<String> to = new ArrayList<>(); - for (int i = 0; i < type.dimensions().size(); ++i) { - String correct = String.format("d%d", i); - String current = type.dimensions().get(i).name(); - if (!current.equals(correct)) { - from.add(current); - to.add(correct); - } - correctType.indexed(correct, dimensionSize(type.dimensions().get(i))); + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + TypedTensorFunction x = inputs.get(1).get(); + TypedTensorFunction y = inputs.get(2).get(); + if ((x.type().rank() != y.type().rank()) || !(tensorSize(x.type()).equals(tensorSize(y.type())))) { + throw new IllegalArgumentException("'Select': input tensors must have the same shape"); } - if (from.size() > 0) { - function = new Rename(function, from, to); - type = correctType.build(); + + if (condition.type().rank() == 0) { + return Optional.of((int)condition.asDouble() == 0 ? y : x); } - return new TypedTensorFunction(type, function); + if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) { + return Optional.of(condition.cellIterator().next().getValue().intValue() == 0 ? y : x); + } + + // The task is to select cells from 'x' or 'y' based on 'condition'. + // If 'condition' is 0 (false), select from 'y', if 1 (true) select + // from 'x'. We do this by individually joining 'x' and 'y' with + // 'condition', and then joining the resulting two tensors. + + Optional<TypedTensorFunction> conditionFunction = importConstantTensor(params, params.node().getInput(0)); + if (!conditionFunction.isPresent()) { + return Optional.empty(); + } + TensorFunction xCond = new Join(x.function(), conditionFunction.get().function(), ScalarFunctions.multiply()); + TensorFunction yCond = new Join(y.function(), conditionFunction.get().function(), new DoubleBinaryOperator() { + @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); } + @Override public String toString() { return "f(a,b)(a * (1-b))"; } + }); + TensorFunction outputFunction = new Join(xCond, yCond, ScalarFunctions.add()); + TypedTensorFunction output = new TypedTensorFunction(x.type(), outputFunction); + return Optional.of(output); } - TypedTensorFunction noOp(List<TypedTensorFunction> arguments) { - ensureArguments(1, arguments, "noOp"); - return arguments.get(0); + private static Optional<TypedTensorFunction> sigmoid(TensorFlowImporter.Parameters params) { + return map(params, ScalarFunctions.sigmoid()); } - TypedTensorFunction expandDims(NodeDef tfNode, SavedModelBundle model, List<TypedTensorFunction> arguments) { - ensureArguments(2, arguments, "expandDims"); - Tensor axis = getConstantTensor(model, tfNode.getInput(1)); - if (axis.type().rank() != 0) { - throw new IllegalArgumentException("Axis argument to ExpandDims must be a scalar"); + private static Optional<TypedTensorFunction> squaredDifference(TensorFlowImporter.Parameters params) { + return join(params, ScalarFunctions.squareddifference()); + } + + private static Optional<TypedTensorFunction> sub(TensorFlowImporter.Parameters params) { + return join(params, ScalarFunctions.subtract()); + } + + private static Optional<TypedTensorFunction> elu(TensorFlowImporter.Parameters params) { + return map(params, ScalarFunctions.elu()); + } + + private static Optional<TypedTensorFunction> relu(TensorFlowImporter.Parameters params) { + return map(params, ScalarFunctions.relu()); + } + + private static Optional<TypedTensorFunction> selu(TensorFlowImporter.Parameters params) { + return map(params, ScalarFunctions.selu()); + } + + private static Optional<TypedTensorFunction> softMax(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 1)) { + return Optional.empty(); } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + TypedTensorFunction a = inputs.get(0).get(); + // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1 + String dimension = "d" + (a.type().rank() - 1); + Softmax softmax = new Softmax(a.function(), dimension); + TypedTensorFunction output = new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax); + return Optional.of(output); + } - TensorFunction inputFunction = arguments.get(0).function(); - TensorType inputType = arguments.get(0).type(); + private static Optional<TypedTensorFunction> variable(TensorFlowImporter.Parameters params) { + return importConstantTensor(params, params.node().getName()); + } - int dimensionToInsert = (int)axis.asDouble(); - if (dimensionToInsert < 0) { - dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; + private static Optional<TypedTensorFunction> noOp(TensorFlowImporter.Parameters params) { + return Optional.empty(); + } + + /* + * Utility + */ + + private static Optional<TypedTensorFunction> join(TensorFlowImporter.Parameters params, DoubleBinaryOperator doubleFunction) { + if (!checkInputs(params, 2)) { + return Optional.empty(); } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); - TensorType.Builder outputTypeBuilder = new TensorType.Builder(); - int dimensionIndex = 0; - for (int i = 0; i < inputType.dimensions().size() + 1; ++i) { - String name = String.format("temp_%d", i); - Long size; - if (i == dimensionToInsert) { - size = 1L; - } else { - size = dimensionSize(inputType.dimensions().get(dimensionIndex)); - dimensionIndex++; - } - outputTypeBuilder.indexed(name, size); + TypedTensorFunction a = inputs.get(0).get(); + TypedTensorFunction b = inputs.get(1).get(); + + if (a.type().rank() == 0 && b.type().rank() > 0) { + return Optional.of(new TypedTensorFunction(b.type(), new Join(a.function(), b.function(), doubleFunction))); + } + if (b.type().rank() == 0 && a.type().rank() > 0) { + return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction))); + } + if (a.type().rank() == b.type().rank()) { + return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction))); } - return reshape(inputFunction, inputType, outputTypeBuilder.build()); + // Well now we have entered the wonderful world of "broadcasting" + // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html + // I'm not able to extract from that any unambiguous specification of which dimensions + // should be "stretched" when the tensor do not have the same number of dimensions. + // From trying this with TensorFlow it appears that the second tensor is matched to the + // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true. + // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first). + + if (a.type().rank() > b.type().rank()) { + TensorFunction renameFunction = renameForBroadcast(a, b); + return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), renameFunction, doubleFunction))); + } + TensorFunction renameFunction = renameForBroadcast(b, a); + return Optional.of(new TypedTensorFunction(b.type(), new Join(renameFunction, b.function(), doubleFunction))); + } + + private static TensorFunction renameForBroadcast(TypedTensorFunction a, TypedTensorFunction b) { + List<String> renameFrom = new ArrayList<>(); + List<String> renameTo = new ArrayList<>(); + int sizeDifference = a.type().rank() - b.type().rank(); + for (int i = 0; i < b.type().rank(); i++) { + renameFrom.add(b.type().dimensions().get(i).name()); + renameTo.add("d" + (sizeDifference + i)); + } + return new Rename(b.function(), renameFrom, renameTo); } - TypedTensorFunction reshape(NodeDef tfNode, SavedModelBundle model, List<TypedTensorFunction> arguments) { - ensureArguments(2, arguments, "reshape"); - Tensor shape = getConstantTensor(model, tfNode.getInput(1)); + private static Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params, DoubleUnaryOperator doubleFunction) { + if (!checkInputs(params, 1)) { + return Optional.empty(); + } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + TypedTensorFunction a = inputs.get(0).get(); + TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type()); + com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction); + return Optional.of(new TypedTensorFunction(resultType, function)); + } - TensorFunction inputFunction = arguments.get(0).function(); - TensorType inputType = arguments.get(0).type(); + private static Optional<TypedTensorFunction> createConstant(TensorFlowImporter.Parameters params, Tensor constant) { + params.result().constant(params.node().getName(), constant); + TypedTensorFunction output = new TypedTensorFunction(constant.type(), + new TensorFunctionNode.TensorFunctionExpressionNode( + new ReferenceNode("constant(\"" + params.node().getName() + "\")"))); + return Optional.of(output); + } - TensorType.Builder outputTypeBuilder = new TensorType.Builder(); - int dimensionIndex = 0; - for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { - Tensor.Cell cell = cellIterator.next(); - int size = cell.getValue().intValue(); - if (size < 0) { - size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / tensorSize(inputType).intValue(); - } - outputTypeBuilder.indexed(String.format("temp_%d", dimensionIndex), size); - dimensionIndex++; + private static Tensor getConstantTensor(TensorFlowImporter.Parameters params, String name) { + if (params.result().constants().containsKey(name)) { + return params.result().constants().get(name); } - return reshape(inputFunction, inputType, outputTypeBuilder.build()); + Session.Runner fetched = params.model().session().runner().fetch(name); + List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); + if (importedTensors.size() != 1) + throw new IllegalStateException("Expected 1 tensor from fetching " + name + ", but got " + + importedTensors.size()); + return TensorConverter.toVespaTensor(importedTensors.get(0)); + } + + private static Optional<TypedTensorFunction> importConstantTensor(TensorFlowImporter.Parameters params, String name) { + AttrValue shapes = params.node().getAttrMap().get("_output_shapes"); + if (shapes == null) + throw new IllegalArgumentException("'" + name + "' is missing a tensor shape"); + Tensor constant = getConstantTensor(params, name); + return createConstant(params, constant); } - private TypedTensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { + private static Optional<TypedTensorFunction> reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { if (!tensorSize(inputType).equals(tensorSize(outputType))) { throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); } @@ -353,10 +594,10 @@ class OperationMapper { Reduce.Aggregator.sum, inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); TypedTensorFunction output = checkNamingConvention(outputType, outputFunction); - return output; + return Optional.of(output); } - private ExpressionNode unrollTensorExpression(TensorType type) { + private static ExpressionNode unrollTensorExpression(TensorType type) { if (type.rank() == 0) { return new ConstantNode(DoubleValue.zero); } @@ -378,80 +619,56 @@ class OperationMapper { return new ArithmeticNode(children, operators); } - TypedTensorFunction select(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result, List<TypedTensorFunction> arguments) { - ensureArguments(3, arguments, "select"); - Tensor condition = getConstantTensor(model, tfNode.getInput(0)); - - TypedTensorFunction x = arguments.get(1); - TypedTensorFunction y = arguments.get(2); - if ((x.type().rank() != y.type().rank()) || !(tensorSize(x.type()).equals(tensorSize(y.type())))) { - throw new IllegalArgumentException("'Select': input tensors must have the same shape"); - } + private static boolean shouldKeepDimensions(TensorFlowImporter.Parameters params) { + AttrValue keepDimsAttr = params.node().getAttrMap().get("keep_dims"); + return keepDimsAttr != null && keepDimsAttr.getB(); + } - if (condition.type().rank() == 0) { - return (int)condition.asDouble() == 0 ? y : x; - } - if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) { - return condition.cellIterator().next().getValue().intValue() == 0 ? y : x; + private static TensorType keepDimensionType(TensorType inputType, List<String> reduceDimensions) { + TensorType.Builder builder = new TensorType.Builder(); + for (TensorType.Dimension dimension: inputType.dimensions()) { + String name = dimension.name(); + Long size = dimensionSize(dimension); + if (reduceDimensions.contains(name)) { + size = 1L; + } + builder.indexed(name, size); } - - // The task is to select cells from 'x' or 'y' based on 'condition'. - // If 'condition' is 0 (false), select from 'y', if 1 (true) select - // from 'x'. We do this by individually joining 'x' and 'y' with - // 'condition', and then joining the resulting two tensors. - - TypedTensorFunction conditionFunction = importConstantTensor(tfNode, model, result, tfNode.getInput(0)); - TensorFunction xCond = new Join(x.function(), conditionFunction.function(), ScalarFunctions.multiply()); - TensorFunction yCond = new Join(y.function(), conditionFunction.function(), new DoubleBinaryOperator() { - @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); } - @Override public String toString() { return "f(a,b)(a * (1-b))"; } - }); - TensorFunction outputFunction = new Join(xCond, yCond, ScalarFunctions.add()); - return new TypedTensorFunction(x.type(), outputFunction); + return builder.build(); } - TypedTensorFunction softmax(List<TypedTensorFunction> arguments) { - ensureArguments(1, arguments, "softmax"); - TypedTensorFunction a = arguments.get(0); - // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1 - String dimension = "d" + (a.type().rank() - 1); - Softmax softmax = new Softmax(a.function(), dimension); - return new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax); + private static TypedTensorFunction checkNamingConvention(TensorType type, TensorFunction function) { + for (int i = 0; i < type.dimensions().size(); ++i) { + String correct = String.format("d%d", i); + String current = type.dimensions().get(i).name(); + if (!current.equals(correct)) { + return fixNamingConvention(type, function); + } + } + return new TypedTensorFunction(type, function); } - TypedTensorFunction squeeze(NodeDef tfNode, List<TypedTensorFunction> arguments) { - ensureArguments(1, arguments, "squeeze"); - - TensorFunction inputFunction = arguments.get(0).function(); - TensorType inputType = arguments.get(0).type(); - List<String> squeezeDimensions; - - AttrValue squeezeDimsAttr = tfNode.getAttrMap().get("squeeze_dims"); - if (squeezeDimsAttr == null) { - squeezeDimensions = inputType.dimensions().stream(). - filter(dim -> dimensionSize(dim) == 1). - map(TensorType.Dimension::name). - collect(Collectors.toList()); - } else { - squeezeDimensions = squeezeDimsAttr.getList().getIList().stream(). - map(i -> i < 0 ? inputType.dimensions().size() - i : i). - map(i -> inputType.dimensions().get(i.intValue())). - filter(dim -> dimensionSize(dim) == 1). - map(TensorType.Dimension::name). - collect(Collectors.toList()); + private static TypedTensorFunction fixNamingConvention(TensorType type, TensorFunction function) { + TensorType.Builder correctType = new TensorType.Builder(); + List<String> from = new ArrayList<>(); + List<String> to = new ArrayList<>(); + for (int i = 0; i < type.dimensions().size(); ++i) { + String correct = String.format("d%d", i); + String current = type.dimensions().get(i).name(); + if (!current.equals(correct)) { + from.add(current); + to.add(correct); + } + correctType.indexed(correct, dimensionSize(type.dimensions().get(i))); } - - if (squeezeDimensions.isEmpty()) { - return arguments.get(0); + if (from.size() > 0) { + function = new Rename(function, from, to); + type = correctType.build(); } - - TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); - TensorType outputType = Reduce.outputType(inputType, squeezeDimensions); - TypedTensorFunction output = checkNamingConvention(outputType, outputFunction); - return output; + return new TypedTensorFunction(type, function); } - private Long tensorSize(TensorType type) { + private static Long tensorSize(TensorType type) { Long size = 1L; for (TensorType.Dimension dimension : type.dimensions()) { size *= dimensionSize(dimension); @@ -459,14 +676,21 @@ class OperationMapper { return size; } - private Long dimensionSize(TensorType.Dimension dim) { + private static Long dimensionSize(TensorType.Dimension dim) { return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); } - private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) { - if ( arguments.size() != count) - throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName + - ", but got " + arguments.size()); + private static boolean checkInputs(TensorFlowImporter.Parameters params, int expected) { + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + if (!inputs.stream().allMatch(Optional::isPresent)) { + return false; + } + if (inputs.size() != expected) { + params.signature().importWarning("Expected " + expected + + " arguments to " + params.node().getOp() + ", but got " + inputs.size()); + return false; + } + return true; } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java index ca880e6f310..b88ffce275a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java @@ -13,11 +13,13 @@ import java.nio.LongBuffer; /** + * Converts TensorFlow tensors into Vespa tensors. + * * @author bratseth */ public class TensorConverter { - public Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) { + public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) { TensorType type = toVespaTensorType(tfTensor.shape()); Values values = readValuesOf(tfTensor); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); @@ -26,7 +28,7 @@ public class TensorConverter { return builder.build(); } - private TensorType toVespaTensorType(long[] shape) { + private static TensorType toVespaTensorType(long[] shape) { TensorType.Builder b = new TensorType.Builder(); int dimensionIndex = 0; for (long dimensionSize : shape) { @@ -36,7 +38,7 @@ public class TensorConverter { return b.build(); } - private Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) { + private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) { switch (tfTensor.dataType()) { case DOUBLE: return new DoubleValues(tfTensor); case FLOAT: return new FloatValues(tfTensor); @@ -149,4 +151,5 @@ public class TensorConverter { } } + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index b9e244a3e08..3a6b3f23a1d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -4,7 +4,6 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.yolean.Exceptions; import org.tensorflow.SavedModelBundle; import org.tensorflow.framework.GraphDef; @@ -12,12 +11,14 @@ import org.tensorflow.framework.MetaGraphDef; import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.SignatureDef; import org.tensorflow.framework.TensorInfo; -import org.tensorflow.framework.TensorShapeProto; import java.io.File; import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; /** @@ -27,8 +28,6 @@ import java.util.stream.Collectors; */ public class TensorFlowImporter { - private final OperationMapper operationMapper = new OperationMapper(); - /** * Imports a saved TensorFlow model from a directory. * The model should be saved as a .pbtxt or .pb file. @@ -68,9 +67,21 @@ public class TensorFlowImporter { for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) { String outputName = output.getKey(); try { - NodeDef node = getNode(nameOf(output.getValue().getName()), graph.getGraphDef()); - importNode(node, graph.getGraphDef(), model, result); - signature.output(outputName, nameOf(output.getValue().getName())); + NodeDef node = getNode(namePartOf(output.getValue().getName()), graph.getGraphDef()); + Parameters params = createParameters(graph.getGraphDef(), model, result, signature, node, ""); + + // Commonly, there are multiple paths through a TensorFlow graph, for instance for + // training and testing/evaluation. Examples are dropout and batch norm. For Vespa + // we are not concerned with training paths, so we can ignore non-supported operations + // as long as they are on a path that will not be evaluated run time. Operations + // that fail import will not have a value present in the optionals. However, the + // final output node must have value present. It is an error if it does not. + + Optional<TypedTensorFunction> outputFunction = importNode(params); + if (!outputFunction.isPresent()) { + throw new IllegalArgumentException(signature.importWarnings().stream().collect(Collectors.joining("\n"))); + } + signature.output(outputName, namePartOf(output.getValue().getName())); } catch (IllegalArgumentException e) { signature.skippedOutput(outputName, Exceptions.toMessageString(e)); @@ -82,92 +93,59 @@ public class TensorFlowImporter { private void importInputs(Map<String, TensorInfo> inputInfoMap, TensorFlowModel.Signature signature) { inputInfoMap.forEach((key, value) -> { - String argumentName = nameOf(value.getName()); - TensorType argumentType = importTensorType(value.getTensorShape()); + String argumentName = namePartOf(value.getName()); + TensorType argumentType = AttrValueConverter.toVespaTensorType(value.getTensorShape()); // Arguments are (Placeholder) nodes, so not local to the signature: signature.owner().argument(argumentName, argumentType); signature.input(key, argumentName); }); } - private TensorType importTensorType(TensorShapeProto tensorShape) { - TensorType.Builder b = new TensorType.Builder(); - for (TensorShapeProto.Dim dimension : tensorShape.getDimList()) { - int dimensionSize = (int)dimension.getSize(); - if (dimensionSize >= 0) - b.indexed("d" + b.rank(), dimensionSize); - else - b.indexed("d" + b.rank()); // unbound size + /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ + private Optional<TypedTensorFunction> importNode(Parameters params) { + String nodeName = params.node().getName(); + if (params.imported().containsKey(nodeName)) { + return Optional.of(params.imported().get(nodeName)); } - return b.build(); - } - /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ - private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) { - TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result); + Optional<TypedTensorFunction> function = OperationMapper.map(params); + if (!function.isPresent()) { + return Optional.empty(); + } + if (!controlDependenciesArePresent(params)) { + return Optional.empty(); + } + params.imported().put(nodeName, function.get()); + try { // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output // will be used. We parse the TensorFunction here to convert it to a RankingExpression tree - result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), function.function().toString())); + params.result().expression(nodeName, + new RankingExpression(params.node().getName(), function.get().function().toString())); return function; } catch (ParseException e) { - throw new RuntimeException("Tensorflow function " + function.function() + + throw new RuntimeException("Tensorflow function " + function.get().function() + " cannot be parsed as a ranking expression", e); } } + private boolean controlDependenciesArePresent(Parameters params) { + return params.node().getInputList().stream() + .filter(TensorFlowImporter::isControlDependency) + .map(nodeName -> importNode(params.copy(getNode(namePartOf(nodeName), params.graph()), indexPartOf(nodeName)))) + .allMatch(Optional::isPresent); + } - - private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) { - // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops - // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/ - switch (tfNode.getOp().toLowerCase()) { - // array ops - case "const" : return operationMapper.constant(tfNode, model, result); - case "expanddims" : return operationMapper.expandDims(tfNode, model, importArguments(tfNode, graph, model, result)); - case "identity" : return operationMapper.identity(tfNode, model, result); - case "placeholder" : return operationMapper.placeholder(tfNode, result); - case "placeholderwithdefault" : return operationMapper.placeholderWithDefault(tfNode, model, result); - case "reshape" : return operationMapper.reshape(tfNode, model, importArguments(tfNode, graph, model, result)); - case "squeeze" : return operationMapper.squeeze(tfNode, importArguments(tfNode, graph, model, result)); - - // math ops - case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add()); - case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos()); - case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result)); - case "maximum" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.max()); - case "mean" : case "reducemean": return operationMapper.mean(tfNode, model, importArguments(tfNode, graph, model, result)); - case "multiply": case "mul" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.multiply()); - case "rsqrt": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.rsqrt()); - case "where3": case "select" : return operationMapper.select(tfNode, model, result, importArguments(tfNode, graph, model, result)); - case "sigmoid": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.sigmoid()); - case "squareddifference" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.squareddifference()); - case "subtract" : case "sub" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.subtract()); - - // nn ops - case "biasadd" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add()); - case "elu": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.elu()); - case "relu": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.relu()); - case "selu": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.selu()); - case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result)); - - // evaluation no-ops - case "stopgradient" : - case "noop": - return operationMapper.noOp(importArguments(tfNode, graph, model, result)); - - // not supported - default : - throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported (" + tfNode.getName() + ")"); - } + private static boolean isControlDependency(String nodeName) { + return nodeName.startsWith("^"); } - private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, - TensorFlowModel result) { - return tfNode.getInputList().stream() - .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result)) - .collect(Collectors.toList()); + private List<Optional<TypedTensorFunction>> importArguments(Parameters params) { + return params.node().getInputList().stream() + .filter(nodeName -> !isControlDependency(nodeName)) + .map(nodeName -> importNode(params.copy(getNode(namePartOf(nodeName), params.graph()), indexPartOf(nodeName)))) + .collect(Collectors.toList()); } private NodeDef getNode(String name, GraphDef graph) { @@ -181,8 +159,94 @@ public class TensorFlowImporter { * A method signature input and output has the form name:index. * This returns the name part without the index. */ - private String nameOf(String name) { + private static String namePartOf(String name) { + name = name.startsWith("^") ? name.substring(1) : name; return name.split(":")[0]; } + /** + * This return the index part. Indexes are used for nodes with + * multiple outputs. + */ + private static String indexPartOf(String name) { + int i = name.indexOf(":"); + return i < 0 ? "" : name.substring(i + 1); + } + + + private Parameters createParameters(GraphDef graph, + SavedModelBundle model, + TensorFlowModel result, + TensorFlowModel.Signature signature, + NodeDef node, + String port) { + return new Parameters(this, graph, model, result, signature, new HashMap<>(), node, port); + } + + /** Parameter object to hold important data while importing */ + static final class Parameters { + private final TensorFlowImporter owner; + private final GraphDef graph; + private final SavedModelBundle model; + private final TensorFlowModel result; + private final TensorFlowModel.Signature signature; + private final Map<String, TypedTensorFunction> imported; + private final NodeDef node; + private final String port; + + private Parameters(TensorFlowImporter owner, + GraphDef graph, + SavedModelBundle model, + TensorFlowModel result, + TensorFlowModel.Signature signature, + Map<String, TypedTensorFunction> imported, + NodeDef node, + String port) { + this.owner = owner; + this.graph = graph; + this.model = model; + this.result = result; + this.signature = signature; + this.imported = imported; + this.node = node; + this.port = port; + } + + GraphDef graph() { + return this.graph; + } + + SavedModelBundle model() { + return this.model; + } + + TensorFlowModel result() { + return this.result; + } + + TensorFlowModel.Signature signature() { + return this.signature; + } + + Map<String, TypedTensorFunction> imported() { + return this.imported; + } + + NodeDef node() { + return node; + } + + String port() { + return port; + } + + Parameters copy(NodeDef node, String port) { + return new Parameters(this.owner, this.graph, this.model, this.result, this.signature, this.imported, node, port); + } + + List<Optional<TypedTensorFunction>> inputs() { + return owner.importArguments(this); + } + } + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java index 1a6c93384ea..60aaf8ddce1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java @@ -5,8 +5,10 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; /** @@ -67,6 +69,7 @@ public class TensorFlowModel { private final Map<String, String> inputs = new HashMap<>(); private final Map<String, String> outputs = new HashMap<>(); private final Map<String, String> skippedOutputs = new HashMap<>(); + private final List<String> importWarnings = new ArrayList<>(); Signature(String name) { this.name = name; @@ -75,6 +78,7 @@ public class TensorFlowModel { void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } void output(String name, String expressionName) { outputs.put(name, expressionName); } void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } + void importWarning(String warning) { importWarnings.add(warning); } public String name() { return name; } @@ -99,6 +103,11 @@ public class TensorFlowModel { */ public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } + /** + * Returns an immutable list of possibly non-fatal warnings encountered during import. + */ + public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } + /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */ public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java index 962f9dda0a6..600225bfe76 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java @@ -19,7 +19,12 @@ final class TypedTensorFunction { this.function = function; } - public TensorType type() { return type; } - public TensorFunction function() { return function; } + public TensorType type() { + return type; + } + + public TensorFunction function() { + return function; + } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java index 746ca3b3200..7485ce69f98 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.transform; import com.yahoo.searchlib.rankingexpression.evaluation.Value; diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/dropout.py b/searchlib/src/test/files/integration/tensorflow/dropout/dropout.py new file mode 100644 index 00000000000..adbf29b9ab6 --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/dropout/dropout.py @@ -0,0 +1,40 @@ + +# Common imports +import numpy as np +import tensorflow as tf +import datetime + +now = datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S") +root_logdir = "tf_logs" +logdir = "{}/run-{}/".format(root_logdir, now) + +n_inputs = 784 +n_outputs = 10 +dropout_rate = 0.5 # == 1 - keep_prob + +X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X") +y = tf.placeholder(tf.int64, shape=(None), name="y") +training = tf.placeholder_with_default(False, shape=(), name='training') + +X_drop = tf.layers.dropout(X, dropout_rate, training=training, name="xdrop") +output = tf.layers.dense(X_drop, n_outputs, name="outputs") + +init = tf.global_variables_initializer() +file_writer = tf.summary.FileWriter(logdir, tf.get_default_graph()) + +with tf.Session() as sess: + init.run() + sess.run(output, feed_dict={training: False, X: np.random.random((1,784))}) + + export_path = "saved" + print('Exporting trained model to ', export_path) + builder = tf.saved_model.builder.SavedModelBuilder(export_path) + signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs = {'x':X}, outputs = {'y':output}) + builder.add_meta_graph_and_variables(sess, + [tf.saved_model.tag_constants.SERVING], + signature_def_map={'serving_default':signature}) + builder.save(as_text=True) + +file_writer.close() + + diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/saved/saved_model.pbtxt b/searchlib/src/test/files/integration/tensorflow/dropout/saved/saved_model.pbtxt new file mode 100644 index 00000000000..52ae5e77a40 --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/dropout/saved/saved_model.pbtxt @@ -0,0 +1,2756 @@ +saved_model_schema_version: 1 +meta_graphs { + meta_info_def { + stripped_op_list { + op { + name: "Add" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_STRING + } + } + } + } + op { + name: "Assign" + input_arg { + name: "ref" + type_attr: "T" + is_ref: true + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output_ref" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + } + attr { + name: "validate_shape" + type: "bool" + default_value { + b: true + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: true + } + } + allows_uninitialized_input: true + } + op { + name: "BiasAdd" + input_arg { + name: "value" + type_attr: "T" + } + input_arg { + name: "bias" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } + } + } + op { + name: "Const" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "value" + type: "tensor" + } + attr { + name: "dtype" + type: "type" + } + } + op { + name: "Floor" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + } + op { + name: "Identity" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "MatMul" + input_arg { + name: "a" + type_attr: "T" + } + input_arg { + name: "b" + type_attr: "T" + } + output_arg { + name: "product" + type_attr: "T" + } + attr { + name: "transpose_a" + type: "bool" + default_value { + b: false + } + } + attr { + name: "transpose_b" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Merge" + input_arg { + name: "inputs" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "output" + type_attr: "T" + } + output_arg { + name: "value_index" + type: DT_INT32 + } + attr { + name: "T" + type: "type" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + } + op { + name: "MergeV2Checkpoints" + input_arg { + name: "checkpoint_prefixes" + type: DT_STRING + } + input_arg { + name: "destination_prefix" + type: DT_STRING + } + attr { + name: "delete_old_dirs" + type: "bool" + default_value { + b: true + } + } + is_stateful: true + } + op { + name: "Mul" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "NoOp" + } + op { + name: "Pack" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + } + attr { + name: "axis" + type: "int" + default_value { + i: 0 + } + } + } + op { + name: "Placeholder" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + } + op { + name: "PlaceholderWithDefault" + input_arg { + name: "input" + type_attr: "dtype" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + } + } + op { + name: "RandomUniform" + input_arg { + name: "shape" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "seed" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "seed2" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true + } + op { + name: "RealDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "RestoreV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + output_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "SaveV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + input_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "Shape" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "T" + type: "type" + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "ShardedFilename" + input_arg { + name: "basename" + type: DT_STRING + } + input_arg { + name: "shard" + type: DT_INT32 + } + input_arg { + name: "num_shards" + type: DT_INT32 + } + output_arg { + name: "filename" + type: DT_STRING + } + } + op { + name: "StringJoin" + input_arg { + name: "inputs" + type: DT_STRING + number_attr: "N" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "separator" + type: "string" + default_value { + s: "" + } + } + } + op { + name: "Sub" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Switch" + input_arg { + name: "data" + type_attr: "T" + } + input_arg { + name: "pred" + type: DT_BOOL + } + output_arg { + name: "output_false" + type_attr: "T" + } + output_arg { + name: "output_true" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "VariableV2" + output_arg { + name: "ref" + type_attr: "dtype" + is_ref: true + } + attr { + name: "shape" + type: "shape" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true + } + } + tags: "serve" + tensorflow_version: "1.4.1" + tensorflow_git_version: "v1.4.0-19-ga52c8d9" + } + graph_def { + node { + name: "X" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + node { + name: "y" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "shape" + value { + shape { + unknown_rank: true + } + } + } + } + node { + name: "training/input" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_BOOL + tensor_shape { + } + bool_val: false + } + } + } + } + node { + name: "training" + op: "PlaceholderWithDefault" + input: "training/input" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "xdrop/cond/Switch" + op: "Switch" + input: "training" + input: "training" + attr { + key: "T" + value { + type: DT_BOOL + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + shape { + } + } + } + } + } + node { + name: "xdrop/cond/switch_t" + op: "Identity" + input: "xdrop/cond/Switch:1" + attr { + key: "T" + value { + type: DT_BOOL + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "xdrop/cond/switch_f" + op: "Identity" + input: "xdrop/cond/Switch" + attr { + key: "T" + value { + type: DT_BOOL + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "xdrop/cond/pred_id" + op: "Identity" + input: "training" + attr { + key: "T" + value { + type: DT_BOOL + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "xdrop/cond/dropout/keep_prob" + op: "Const" + input: "^xdrop/cond/switch_t" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + } + node { + name: "xdrop/cond/dropout/Shape/Switch" + op: "Switch" + input: "X" + input: "xdrop/cond/pred_id" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@X" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "xdrop/cond/dropout/Shape" + op: "Shape" + input: "xdrop/cond/dropout/Shape/Switch:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "xdrop/cond/dropout/random_uniform/min" + op: "Const" + input: "^xdrop/cond/switch_t" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } + } + node { + name: "xdrop/cond/dropout/random_uniform/max" + op: "Const" + input: "^xdrop/cond/switch_t" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } + } + node { + name: "xdrop/cond/dropout/random_uniform/RandomUniform" + op: "RandomUniform" + input: "xdrop/cond/dropout/Shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + } + node { + name: "xdrop/cond/dropout/random_uniform/sub" + op: "Sub" + input: "xdrop/cond/dropout/random_uniform/max" + input: "xdrop/cond/dropout/random_uniform/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "xdrop/cond/dropout/random_uniform/mul" + op: "Mul" + input: "xdrop/cond/dropout/random_uniform/RandomUniform" + input: "xdrop/cond/dropout/random_uniform/sub" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "xdrop/cond/dropout/random_uniform" + op: "Add" + input: "xdrop/cond/dropout/random_uniform/mul" + input: "xdrop/cond/dropout/random_uniform/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "xdrop/cond/dropout/add" + op: "Add" + input: "xdrop/cond/dropout/keep_prob" + input: "xdrop/cond/dropout/random_uniform" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "xdrop/cond/dropout/Floor" + op: "Floor" + input: "xdrop/cond/dropout/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "xdrop/cond/dropout/div" + op: "RealDiv" + input: "xdrop/cond/dropout/Shape/Switch:1" + input: "xdrop/cond/dropout/keep_prob" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "xdrop/cond/dropout/mul" + op: "Mul" + input: "xdrop/cond/dropout/div" + input: "xdrop/cond/dropout/Floor" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "xdrop/cond/Identity/Switch" + op: "Switch" + input: "X" + input: "xdrop/cond/pred_id" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@X" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "xdrop/cond/Identity" + op: "Identity" + input: "xdrop/cond/Identity/Switch" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "xdrop/cond/Merge" + op: "Merge" + input: "xdrop/cond/Identity" + input: "xdrop/cond/dropout/mul" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + shape { + } + } + } + } + } + node { + name: "outputs/kernel/Initializer/random_uniform/shape" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@outputs/kernel" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\020\003\000\000\n\000\000\000" + } + } + } + } + node { + name: "outputs/kernel/Initializer/random_uniform/min" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@outputs/kernel" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -0.08692913502454758 + } + } + } + } + node { + name: "outputs/kernel/Initializer/random_uniform/max" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@outputs/kernel" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.08692913502454758 + } + } + } + } + node { + name: "outputs/kernel/Initializer/random_uniform/RandomUniform" + op: "RandomUniform" + input: "outputs/kernel/Initializer/random_uniform/shape" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@outputs/kernel" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + } + node { + name: "outputs/kernel/Initializer/random_uniform/sub" + op: "Sub" + input: "outputs/kernel/Initializer/random_uniform/max" + input: "outputs/kernel/Initializer/random_uniform/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@outputs/kernel" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "outputs/kernel/Initializer/random_uniform/mul" + op: "Mul" + input: "outputs/kernel/Initializer/random_uniform/RandomUniform" + input: "outputs/kernel/Initializer/random_uniform/sub" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@outputs/kernel" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "outputs/kernel/Initializer/random_uniform" + op: "Add" + input: "outputs/kernel/Initializer/random_uniform/mul" + input: "outputs/kernel/Initializer/random_uniform/min" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@outputs/kernel" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "outputs/kernel" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@outputs/kernel" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "outputs/kernel/Assign" + op: "Assign" + input: "outputs/kernel" + input: "outputs/kernel/Initializer/random_uniform" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@outputs/kernel" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "outputs/kernel/read" + op: "Identity" + input: "outputs/kernel" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@outputs/kernel" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "outputs/bias/Initializer/zeros" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@outputs/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 10 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "outputs/bias" + op: "VariableV2" + attr { + key: "_class" + value { + list { + s: "loc:@outputs/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "outputs/bias/Assign" + op: "Assign" + input: "outputs/bias" + input: "outputs/bias/Initializer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@outputs/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "outputs/bias/read" + op: "Identity" + input: "outputs/bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@outputs/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "outputs/MatMul" + op: "MatMul" + input: "xdrop/cond/Merge" + input: "outputs/kernel/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "outputs/BiasAdd" + op: "BiasAdd" + input: "outputs/MatMul" + input: "outputs/bias/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + } + node { + name: "init" + op: "NoOp" + input: "^outputs/kernel/Assign" + input: "^outputs/bias/Assign" + } + node { + name: "save/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "model" + } + } + } + } + node { + name: "save/StringJoin/inputs_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "_temp_8370883d2d9a4584b706fa987019b91d/part" + } + } + } + } + node { + name: "save/StringJoin" + op: "StringJoin" + input: "save/Const" + input: "save/StringJoin/inputs_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "separator" + value { + s: "" + } + } + } + node { + name: "save/num_shards" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "save/ShardedFilename/shard" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "save/ShardedFilename" + op: "ShardedFilename" + input: "save/StringJoin" + input: "save/ShardedFilename/shard" + input: "save/num_shards" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/SaveV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "outputs/bias" + string_val: "outputs/kernel" + } + } + } + } + node { + name: "save/SaveV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "" + string_val: "" + } + } + } + } + node { + name: "save/SaveV2" + op: "SaveV2" + input: "save/ShardedFilename" + input: "save/SaveV2/tensor_names" + input: "save/SaveV2/shape_and_slices" + input: "outputs/bias" + input: "outputs/kernel" + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + } + node { + name: "save/control_dependency" + op: "Identity" + input: "save/ShardedFilename" + input: "^save/SaveV2" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@save/ShardedFilename" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/MergeV2Checkpoints/checkpoint_prefixes" + op: "Pack" + input: "save/ShardedFilename" + input: "^save/control_dependency" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "save/MergeV2Checkpoints" + op: "MergeV2Checkpoints" + input: "save/MergeV2Checkpoints/checkpoint_prefixes" + input: "save/Const" + attr { + key: "delete_old_dirs" + value { + b: true + } + } + } + node { + name: "save/Identity" + op: "Identity" + input: "save/Const" + input: "^save/control_dependency" + input: "^save/MergeV2Checkpoints" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/RestoreV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "outputs/bias" + } + } + } + } + node { + name: "save/RestoreV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2/tensor_names" + input: "save/RestoreV2/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign" + op: "Assign" + input: "outputs/bias" + input: "save/RestoreV2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@outputs/bias" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_1/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "outputs/kernel" + } + } + } + } + node { + name: "save/RestoreV2_1/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_1" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_1/tensor_names" + input: "save/RestoreV2_1/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_1" + op: "Assign" + input: "outputs/kernel" + input: "save/RestoreV2_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@outputs/kernel" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/restore_shard" + op: "NoOp" + input: "^save/Assign" + input: "^save/Assign_1" + } + node { + name: "save/restore_all" + op: "NoOp" + input: "^save/restore_shard" + } + versions { + producer: 24 + } + } + saver_def { + filename_tensor_name: "save/Const:0" + save_tensor_name: "save/Identity:0" + restore_op_name: "save/restore_all" + max_to_keep: 5 + sharded: true + keep_checkpoint_every_n_hours: 10000.0 + version: V2 + } + collection_def { + key: "cond_context" + value { + bytes_list { + value: "\n\024xdrop/cond/cond_text\022\024xdrop/cond/pred_id:0\032\025xdrop/cond/switch_t:0 \001*\241\004\n\003X:0\n\032xdrop/cond/dropout/Floor:0\n!xdrop/cond/dropout/Shape/Switch:1\n\032xdrop/cond/dropout/Shape:0\n\030xdrop/cond/dropout/add:0\n\030xdrop/cond/dropout/div:0\n\036xdrop/cond/dropout/keep_prob:0\n\030xdrop/cond/dropout/mul:0\n1xdrop/cond/dropout/random_uniform/RandomUniform:0\n\'xdrop/cond/dropout/random_uniform/max:0\n\'xdrop/cond/dropout/random_uniform/min:0\n\'xdrop/cond/dropout/random_uniform/mul:0\n\'xdrop/cond/dropout/random_uniform/sub:0\n#xdrop/cond/dropout/random_uniform:0\n\024xdrop/cond/pred_id:0\n\025xdrop/cond/switch_t:0\022(\n\003X:0\022!xdrop/cond/dropout/Shape/Switch:1" + value: "\n\026xdrop/cond/cond_text_1\022\024xdrop/cond/pred_id:0\032\025xdrop/cond/switch_f:0*\214\001\n\003X:0\n\034xdrop/cond/Identity/Switch:0\n\025xdrop/cond/Identity:0\n\024xdrop/cond/pred_id:0\n\025xdrop/cond/switch_f:0\022#\n\003X:0\022\034xdrop/cond/Identity/Switch:0" + } + } + } + collection_def { + key: "trainable_variables" + value { + bytes_list { + value: "\n\020outputs/kernel:0\022\025outputs/kernel/Assign\032\025outputs/kernel/read:02+outputs/kernel/Initializer/random_uniform:0" + value: "\n\016outputs/bias:0\022\023outputs/bias/Assign\032\023outputs/bias/read:02 outputs/bias/Initializer/zeros:0" + } + } + } + collection_def { + key: "variables" + value { + bytes_list { + value: "\n\020outputs/kernel:0\022\025outputs/kernel/Assign\032\025outputs/kernel/read:02+outputs/kernel/Initializer/random_uniform:0" + value: "\n\016outputs/bias:0\022\023outputs/bias/Assign\032\023outputs/bias/read:02 outputs/bias/Initializer/zeros:0" + } + } + } + signature_def { + key: "serving_default" + value { + inputs { + key: "x" + value { + name: "X:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + outputs { + key: "y" + value { + name: "outputs/BiasAdd:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + method_name: "tensorflow/serving/predict" + } + } +} diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.data-00000-of-00001 b/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.data-00000-of-00001 Binary files differnew file mode 100644 index 00000000000..e1b1b015b9f --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.data-00000-of-00001 diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.index b/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.index Binary files differnew file mode 100644 index 00000000000..04ace49d9e3 --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.index diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java new file mode 100644 index 00000000000..c6ee586a78c --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java @@ -0,0 +1,29 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * @author lesters + */ +public class BatchNormImportTestCase { + + @Test + public void testBatchNormImport() { + TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/batch_norm/saved"); + TensorFlowModel.Signature signature = model.get().signature("serving_default"); + + assertEquals("Has skipped outputs", + 0, model.get().signature("serving_default").skippedOutputs().size()); + + RankingExpression output = signature.outputExpression("y"); + assertNotNull(output); + assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName()); + model.assertEqualResult("X", output.getName()); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java new file mode 100644 index 00000000000..b59b4750911 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -0,0 +1,31 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * @author lesters + */ +public class DropoutImportTestCase { + + @Test + public void testDropoutImport() { + TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved"); + TensorFlowModel.Signature signature = model.get().signature("serving_default"); + + assertEquals("Has skipped outputs", + 0, model.get().signature("serving_default").skippedOutputs().size()); + + RankingExpression output = signature.outputExpression("y"); + assertNotNull(output); + assertEquals("outputs/BiasAdd", output.getName()); + assertEquals("join(rename(reduce(join(X, rename(constant(\"outputs/kernel\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"outputs/bias\"), d0, d1), f(a,b)(a + b))", + output.getRoot().toString()); + model.assertEqualResult("X", output.getName()); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java new file mode 100644 index 00000000000..f12b9a2c628 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -0,0 +1,62 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * @author bratseth + */ +public class MnistSoftmaxImportTestCase { + + @Test + public void testMnistSoftmaxImport() { + TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/mnist_softmax/saved"); + + // Check constants + assertEquals(2, model.get().constants().size()); + + Tensor constant0 = model.get().constants().get("Variable"); + assertNotNull(constant0); + assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), + constant0.type()); + assertEquals(7840, constant0.size()); + + Tensor constant1 = model.get().constants().get("Variable_1"); + assertNotNull(constant1); + assertEquals(new TensorType.Builder().indexed("d0", 10).build(), + constant1.type()); + assertEquals(10, constant1.size()); + + // Check signatures + assertEquals(1, model.get().signatures().size()); + TensorFlowModel.Signature signature = model.get().signatures().get("serving_default"); + assertNotNull(signature); + + // ... signature inputs + assertEquals(1, signature.inputs().size()); + TensorType argument0 = signature.inputArgument("x"); + assertNotNull(argument0); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); + + // ... signature outputs + assertEquals(1, signature.outputs().size()); + RankingExpression output = signature.outputExpression("y"); + assertNotNull(output); + assertEquals("add", output.getName()); + assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))", + output.getRoot().toString()); + + // Test execution + model.assertEqualResult("Placeholder", "Variable/read"); + model.assertEqualResult("Placeholder", "Variable_1/read"); + model.assertEqualResult("Placeholder", "MatMul"); + model.assertEqualResult("Placeholder", "add"); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java deleted file mode 100644 index 13d042ee5dd..00000000000 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import org.junit.Test; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Session; - -import java.nio.FloatBuffer; -import java.util.List; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - -/** - * @author bratseth - */ -public class TensorflowImportTestCase { - - @Test - public void testMnistSoftmaxImport() { - String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved"; - SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); - TensorFlowModel result = new TensorFlowImporter().importModel(model); - - // Check constants - assertEquals(2, result.constants().size()); - - Tensor constant0 = result.constants().get("Variable"); - assertNotNull(constant0); - assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), - constant0.type()); - assertEquals(7840, constant0.size()); - - Tensor constant1 = result.constants().get("Variable_1"); - assertNotNull(constant1); - assertEquals(new TensorType.Builder().indexed("d0", 10).build(), - constant1.type()); - assertEquals(10, constant1.size()); - - // Check signatures - assertEquals(1, result.signatures().size()); - TensorFlowModel.Signature signature = result.signatures().get("serving_default"); - assertNotNull(signature); - - // ... signature inputs - assertEquals(1, signature.inputs().size()); - TensorType argument0 = signature.inputArgument("x"); - assertNotNull(argument0); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); - - // ... signature outputs - assertEquals(1, signature.outputs().size()); - RankingExpression output = signature.outputExpression("y"); - assertNotNull(output); - assertEquals("add", output.getName()); - assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))", - toNonPrimitiveString(output)); - - // Test execution - assertEqualResult(model, result, "Placeholder", "Variable/read"); - assertEqualResult(model, result, "Placeholder", "Variable_1/read"); - assertEqualResult(model, result, "Placeholder", "MatMul"); - assertEqualResult(model, result, "Placeholder", "add"); - } - - @Test - public void testBatchNormImport() { - String modelDir = "src/test/files/integration/tensorflow/batch_norm/saved"; - SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); - TensorFlowModel result = new TensorFlowImporter().importModel(model); - TensorFlowModel.Signature signature = result.signature("serving_default"); - - assertEquals("Has skipped outputs", 0, result.signature("serving_default").skippedOutputs().size()); - - RankingExpression output = signature.outputExpression("y"); - assertNotNull(output); - assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName()); - assertEqualResult(model, result, "X", output.getName()); - - } - - private void assertEqualResult(SavedModelBundle model, TensorFlowModel result, String inputName, String operationName) { - Tensor tfResult = tensorFlowExecute(model, inputName, operationName); - Context context = contextFrom(result); - Tensor placeholder = placeholderArgument(); - context.put(inputName, new TensorValue(placeholder)); - Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor(); - assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); - } - - private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) { - Session.Runner runner = model.session().runner(); - org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784)); - runner.feed(inputName, placeholder); - List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run(); - assertEquals(1, results.size()); - return new TensorConverter().toVespaTensor(results.get(0)); - } - - private Context contextFrom(TensorFlowModel result) { - MapContext context = new MapContext(); - result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); - return context; - } - - private String toNonPrimitiveString(RankingExpression expression) { - // toString on the wrapping expression will map to primitives, which is harder to read - return ((TensorFunctionNode)expression.getRoot()).function().toString(); - } - - private Tensor placeholderArgument() { - int size = 784; - Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", size).build()); - for (int i = 0; i < size; i++) - b.cell(0, 0, i); - return b.build(); - } - -} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java new file mode 100644 index 00000000000..127b63c66c9 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -0,0 +1,72 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; + +import java.nio.FloatBuffer; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * Helper for TensorFlow import tests: Imports a model and provides asserts on it. + * This currently assumes the TensorFlow model takes a single input of type tensor(d0[1],d1[784]) + * + * @author bratseth + */ +public class TestableTensorFlowModel { + + private SavedModelBundle tensorFlowModel; + private TensorFlowModel model; + + // Sizes of the input vector + private final int d0Size = 1; + private final int d1Size = 784; + + public TestableTensorFlowModel(String modelDir) { + tensorFlowModel = SavedModelBundle.load(modelDir, "serve"); + model = new TensorFlowImporter().importModel(tensorFlowModel); + } + + public TensorFlowModel get() { return model; } + + public void assertEqualResult(String inputName, String operationName) { + Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName); + Context context = contextFrom(model); + Tensor placeholder = placeholderArgument(); + context.put(inputName, new TensorValue(placeholder)); + Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); + assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); + } + + private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) { + Session.Runner runner = model.session().runner(); + org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, + FloatBuffer.allocate(d0Size * d1Size)); + runner.feed(inputName, placeholder); + List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run(); + assertEquals(1, results.size()); + return new TensorConverter().toVespaTensor(results.get(0)); + } + + private Context contextFrom(TensorFlowModel result) { + MapContext context = new MapContext(); + result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); + return context; + } + + private Tensor placeholderArgument() { + Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build()); + for (int d0 = 0; d0 < d0Size; d0++) + for (int d1 = 0; d1 < d1Size; d1++) + b.cell(0, d0, d1); + return b.build(); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java index 1960c1fe876..c28873684ee 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.tensor; import com.fasterxml.jackson.databind.JsonNode; diff --git a/searchlib/src/tests/attribute/enumstore/enumstore_test.cpp b/searchlib/src/tests/attribute/enumstore/enumstore_test.cpp index fa59baa2bc5..daff432d68d 100644 --- a/searchlib/src/tests/attribute/enumstore/enumstore_test.cpp +++ b/searchlib/src/tests/attribute/enumstore/enumstore_test.cpp @@ -371,8 +371,10 @@ EnumStoreTest::testCompaction(bool hasPostings, bool disableReEnumerate) { // entrySize = 15 before alignment uint32_t entrySize = EnumStoreType::alignEntrySize(15); - uint32_t bufferSize = entrySize * 5; - EnumStoreType ses(bufferSize, hasPostings); + uint32_t initBufferSize = entrySize * 5; + EnumStoreType ses(initBufferSize, hasPostings); + // Note: Sizes of underlying data store buffers are power of 2. + uint32_t adjustedBufferSize = vespalib::roundUp2inN(initBufferSize) - RESERVED_BYTES; EnumIndex idx; std::vector<EnumIndex> indices; typename EnumStoreType::Type t = "foo"; @@ -385,18 +387,19 @@ EnumStoreTest::testCompaction(bool hasPostings, bool disableReEnumerate) // fill with unique values for (uint32_t i = 0; i < 5; ++i) { - EXPECT_TRUE(ses.getRemaining() == bufferSize - i * entrySize); + size_t expRemaining = adjustedBufferSize - i * entrySize; + EXPECT_EQUAL(expRemaining, ses.getRemaining()); ses.addEnum(uniques[i].c_str(), idx); ses.incRefCount(idx); EXPECT_TRUE(ses.getRefCount(idx)); indices.push_back(idx); } - EXPECT_EQUAL(0u, ses.getRemaining()); - EXPECT_EQUAL(0u, ses.getBuffer(0).remaining()); + EXPECT_EQUAL(32u, ses.getRemaining()); + EXPECT_EQUAL(32u, ses.getBuffer(0).remaining()); EXPECT_EQUAL(entrySize * 5 + RESERVED_BYTES, ses.getBuffer(0).size()); EXPECT_EQUAL(RESERVED_BYTES, ses.getBuffer(0).getDeadElems()); uint32_t failEntrySize = ses.getEntrySize("enum05"); - EXPECT_TRUE(failEntrySize > ses.getRemaining()); + EXPECT_EQUAL(16u, failEntrySize); // change from enum00 -> enum01 ses.decRefCount(indices[0]); @@ -525,7 +528,9 @@ EnumStoreTest::testReset(bool hasPostings) } ses.reset(builder); - EXPECT_EQUAL(RESERVED_BYTES, ses.getRemaining()); + // Note: Sizes of underlying data store buffers are power of 2. + EXPECT_EQUAL(524288u, ses.getCapacity()); + EXPECT_EQUAL(204272u, ses.getRemaining()); // check for old unique strings for (StringVector::iterator iter = uniques.begin(); iter != uniques.end(); ++iter) { @@ -597,7 +602,8 @@ EnumStoreTest::testHoldListAndGeneration() } } - EXPECT_EQUAL(0u, ses.getRemaining()); + // Note: Sizes of underlying data store buffers are power of 2. + EXPECT_EQUAL(432u, ses.getRemaining()); EXPECT_EQUAL(RESERVED_BYTES, ses.getBuffer(0).getDeadElems()); // remove all uniques @@ -657,7 +663,8 @@ EnumStoreTest::testMemoryUsage() // usage before inserting enums MemoryUsage usage = ses.getMemoryUsage(); EXPECT_EQUAL(ses.getNumUniques(), uint32_t(0)); - EXPECT_EQUAL(enumStoreAlign(200u) + RESERVED_BYTES, usage.allocatedBytes()); + // Note: Sizes of underlying data store buffers are power of 2. + EXPECT_EQUAL(vespalib::roundUp2inN(enumStoreAlign(200u) + RESERVED_BYTES), usage.allocatedBytes()); EXPECT_EQUAL(RESERVED_BYTES, usage.usedBytes()); EXPECT_EQUAL(RESERVED_BYTES, usage.deadBytes()); EXPECT_EQUAL(0u, usage.allocatedBytesOnHold()); @@ -672,7 +679,8 @@ EnumStoreTest::testMemoryUsage() // usage after inserting enums usage = ses.getMemoryUsage(); EXPECT_EQUAL(ses.getNumUniques(), num); - EXPECT_EQUAL(enumStoreAlign(200u) + RESERVED_BYTES, usage.allocatedBytes()); + // Note: Sizes of underlying data store buffers are power of 2. + EXPECT_EQUAL(vespalib::roundUp2inN(enumStoreAlign(200u) + RESERVED_BYTES), usage.allocatedBytes()); EXPECT_EQUAL(num * entrySize + RESERVED_BYTES, usage.usedBytes()); EXPECT_EQUAL(RESERVED_BYTES, usage.deadBytes()); EXPECT_EQUAL(0u, usage.allocatedBytesOnHold()); @@ -689,7 +697,8 @@ EnumStoreTest::testMemoryUsage() // usage after removing enums usage = ses.getMemoryUsage(); EXPECT_EQUAL(ses.getNumUniques(), num / 2); - EXPECT_EQUAL(enumStoreAlign(200u) + RESERVED_BYTES, usage.allocatedBytes()); + // Note: Sizes of underlying data store buffers are power of 2. + EXPECT_EQUAL(vespalib::roundUp2inN(enumStoreAlign(200u) + RESERVED_BYTES), usage.allocatedBytes()); EXPECT_EQUAL(num * entrySize + RESERVED_BYTES, usage.usedBytes()); EXPECT_EQUAL((num / 2) * entrySize + RESERVED_BYTES, usage.deadBytes()); EXPECT_EQUAL(0u, usage.allocatedBytesOnHold()); diff --git a/searchlib/src/tests/btree/btree_test.cpp b/searchlib/src/tests/btree/btree_test.cpp index 5795385250f..1f39c7315e8 100644 --- a/searchlib/src/tests/btree/btree_test.cpp +++ b/searchlib/src/tests/btree/btree_test.cpp @@ -1022,6 +1022,15 @@ Test::requireThatTreeIteratorAssignWorks() } } +size_t +adjustAllocatedBytes(size_t nodeCount, size_t nodeSize) +{ + // Note: Sizes of underlying data store buffers are power of 2. + size_t allocatedBytes = vespalib::roundUp2inN(nodeCount * nodeSize); + size_t adjustedNodeCount = allocatedBytes / nodeSize; + return adjustedNodeCount * nodeSize; +} + void Test::requireThatMemoryUsageIsCalculated() { @@ -1041,8 +1050,8 @@ Test::requireThatMemoryUsageIsCalculated() MemoryUsage mu; const uint32_t initialInternalNodes = 128u; const uint32_t initialLeafNodes = 128u; - mu.incAllocatedBytes(sizeof(INode) * initialInternalNodes); - mu.incAllocatedBytes(sizeof(LNode) * initialLeafNodes); + mu.incAllocatedBytes(adjustAllocatedBytes(initialInternalNodes, sizeof(INode))); + mu.incAllocatedBytes(adjustAllocatedBytes(initialLeafNodes, sizeof(LNode))); mu.incUsedBytes(sizeof(INode)); mu.incDeadBytes(sizeof(INode)); EXPECT_TRUE(assertMemoryUsage(mu, tm.getMemoryUsage())); @@ -1071,8 +1080,8 @@ Test::requireThatMemoryUsageIsCalculated() gh.incGeneration(); tm.trimHoldLists(gh.getFirstUsedGeneration()); mu = MemoryUsage(); - mu.incAllocatedBytes(sizeof(INode) * initialInternalNodes); - mu.incAllocatedBytes(sizeof(LNode) * initialLeafNodes); + mu.incAllocatedBytes(adjustAllocatedBytes(initialInternalNodes, sizeof(INode))); + mu.incAllocatedBytes(adjustAllocatedBytes(initialLeafNodes, sizeof(LNode))); mu.incUsedBytes(sizeof(INode) * 2); mu.incDeadBytes(sizeof(INode) * 2); mu.incUsedBytes(sizeof(LNode)); diff --git a/searchlib/src/tests/datastore/array_store/array_store_test.cpp b/searchlib/src/tests/datastore/array_store/array_store_test.cpp index fff4445890b..dab853305c6 100644 --- a/searchlib/src/tests/datastore/array_store/array_store_test.cpp +++ b/searchlib/src/tests/datastore/array_store/array_store_test.cpp @@ -316,7 +316,7 @@ TEST_F("require that used, onHold and dead memory usage is tracked for large arr TEST_F("require that address space usage is ratio between used clusters and number of possible clusters", NumberFixture(3)) { f.add({2,2}); - f.add({4,4,4}); + f.add({3,3,3}); // 1 cluster is reserved (buffer 0, offset 0). EXPECT_EQUAL(3u, f.store.addressSpaceUsage().used()); EXPECT_EQUAL(1u, f.store.addressSpaceUsage().dead()); @@ -324,11 +324,15 @@ TEST_F("require that address space usage is ratio between used clusters and numb /* * Expected limit is sum of allocated clusters for active buffers and * potentially allocated clusters for free buffers. If all buffers were - * free then the limit would be 4 Gi. Then we subtract clusters for 4 - * buffers that are not free, and add their actual number of allocated - * clusters (16 clusters per buffer). + * free then the limit would be 4 Gi. + * Then we subtract clusters for 4 buffers that are not free (arraySize=1,2,3 + largeArray), + * and add their actual number of allocated clusters (16 clusters per buffer). + * Note: arraySize=3 has 21 clusters as allocated buffer is rounded up to power of 2: + * 16 * 3 * sizeof(int) = 192 -> 256. + * allocated elements = 256 / sizeof(int) = 64. + * limit = 64 / 3 = 21. */ - size_t expLimit = fourgig - 4 * F1::EntryRefType::offsetSize() + 4 * 16; + size_t expLimit = fourgig - 4 * F1::EntryRefType::offsetSize() + 3 * 16 + 21; EXPECT_EQUAL(static_cast<double>(2)/ expLimit, f.store.addressSpaceUsage().usage()); EXPECT_EQUAL(expLimit, f.store.addressSpaceUsage().limit()); } diff --git a/searchlib/src/tests/datastore/datastore/datastore_test.cpp b/searchlib/src/tests/datastore/datastore/datastore_test.cpp index 2463439c47c..c3de2261745 100644 --- a/searchlib/src/tests/datastore/datastore/datastore_test.cpp +++ b/searchlib/src/tests/datastore/datastore/datastore_test.cpp @@ -11,6 +11,8 @@ LOG_SETUP("datastore_test"); namespace search { namespace datastore { +using vespalib::alloc::MemoryAllocator; + struct IntReclaimer { static void reclaim(int *) {} @@ -65,21 +67,22 @@ public: using GrowthStats = std::vector<int>; -constexpr float ALLOC_GROW_FACTOR = 0.5; +constexpr float ALLOC_GROW_FACTOR = 0.4; +constexpr size_t HUGE_PAGE_CLUSTER_SIZE = (MemoryAllocator::HUGEPAGE_SIZE / sizeof(int)); class GrowStore { - using Store = DataStoreT<EntryRefT<22>>; + using Store = DataStoreT<EntryRefT<24>>; using RefType = Store::RefType; Store _store; BufferType<int> _firstType; BufferType<int> _type; uint32_t _typeId; public: - GrowStore(size_t minSize, size_t minSwitch) + GrowStore(size_t minClusters, size_t maxClusters, size_t numClustersForNewBuffer) : _store(), - _firstType(1, 1, 64, 0, ALLOC_GROW_FACTOR), - _type(1, minSize, 64, minSwitch, ALLOC_GROW_FACTOR), + _firstType(1, 1, maxClusters, 0, ALLOC_GROW_FACTOR), + _type(1, minClusters, maxClusters, numClustersForNewBuffer, ALLOC_GROW_FACTOR), _typeId(0) { (void) _store.addType(&_firstType); @@ -460,11 +463,11 @@ namespace { void assertGrowStats(GrowthStats expSizes, GrowthStats expFirstBufSizes, size_t expInitMemUsage, - size_t minSize, size_t minSwitch) + size_t minClusters, size_t numClustersForNewBuffer, size_t maxClusters = 128) { - EXPECT_EQUAL(expSizes, GrowStore(minSize, minSwitch).getGrowthStats(expSizes.size())); - EXPECT_EQUAL(expFirstBufSizes, GrowStore(minSize, minSwitch).getFirstBufGrowStats()); - EXPECT_EQUAL(expInitMemUsage, GrowStore(minSize, minSwitch).getMemoryUsage().allocatedBytes()); + EXPECT_EQUAL(expSizes, GrowStore(minClusters, maxClusters, numClustersForNewBuffer).getGrowthStats(expSizes.size())); + EXPECT_EQUAL(expFirstBufSizes, GrowStore(minClusters, maxClusters, numClustersForNewBuffer).getFirstBufGrowStats()); + EXPECT_EQUAL(expInitMemUsage, GrowStore(minClusters, maxClusters, numClustersForNewBuffer).getMemoryUsage().allocatedBytes()); } } @@ -472,23 +475,29 @@ void assertGrowStats(GrowthStats expSizes, TEST("require that buffer growth works") { // Always switch to new buffer, min size 4 - TEST_DO(assertGrowStats({ 4, 4, 4, 6, 9, 13, 20, 30, 45, 64 }, + TEST_DO(assertGrowStats({ 4, 4, 4, 4, 8, 16, 16, 32, 64, 64 }, { 4 }, 20, 4, 0)); // Resize if buffer size is less than 4, min size 0 - TEST_DO(assertGrowStats({ 3, 3, 3, 4, 6, 9, 14, 21, 31, 47 }, - { 0, 1, 2, 3 }, 4, 0, 4)); + TEST_DO(assertGrowStats({ 4, 4, 4, 4, 8, 16, 16, 32, 64, 64 }, + { 0, 1, 2, 4 }, 4, 0, 4)); // Always switch to new buffer, min size 16 - TEST_DO(assertGrowStats({ 16, 16, 16, 24, 36, 54, 64, 64, 64 }, + TEST_DO(assertGrowStats({ 16, 16, 16, 32, 32, 64, 128, 128, 128 }, { 16 }, 68, 16, 0)); // Resize if buffer size is less than 16, min size 0 - TEST_DO(assertGrowStats({ 19, 19, 19, 28, 42, 63, 64, 64, 64 }, - { 0, 1, 2, 3, 4, 6, 9, 13, 19 }, 4, 0, 16)); + TEST_DO(assertGrowStats({ 16, 16, 16, 32, 32, 64, 128, 128, 128 }, + { 0, 1, 2, 4, 8, 16 }, 4, 0, 16)); // Resize if buffer size is less than 16, min size 4 - TEST_DO(assertGrowStats({ 19, 19, 19, 28, 42, 63, 64, 64, 64 }, - { 4, 6, 9, 13, 19 }, 20, 4, 16)); + TEST_DO(assertGrowStats({ 16, 16, 16, 32, 32, 64, 128, 128, 128 }, + { 4, 8, 16 }, 20, 4, 16)); // Always switch to new buffer, min size 0 - TEST_DO(assertGrowStats({ 1, 1, 1, 1, 2, 3, 4, 6, 9 }, + TEST_DO(assertGrowStats({ 1, 1, 1, 1, 1, 2, 2, 4, 8, 8, 16, 32 }, { 0, 1 }, 4, 0, 0)); + + // Buffers with sizes larger than the huge page size of the mmap allocator. + ASSERT_EQUAL(524288u, HUGE_PAGE_CLUSTER_SIZE); + TEST_DO(assertGrowStats({ 262144, 262144, 262144, 524288, 524288, 524288 * 2, 524288 * 3, 524288 * 4, 524288 * 5, 524288 * 5 }, + { 0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144 }, + 4, 0, HUGE_PAGE_CLUSTER_SIZE / 2, HUGE_PAGE_CLUSTER_SIZE * 5)); } } diff --git a/searchlib/src/tests/memoryindex/memoryindex/memoryindex_test.cpp b/searchlib/src/tests/memoryindex/memoryindex/memoryindex_test.cpp index 77a687796b3..9de6ac9f310 100644 --- a/searchlib/src/tests/memoryindex/memoryindex/memoryindex_test.cpp +++ b/searchlib/src/tests/memoryindex/memoryindex/memoryindex_test.cpp @@ -381,7 +381,7 @@ TEST("requireThatNumDocsAndDocIdLimitIsReturned") TEST("requireThatWeUnderstandTheMemoryFootprint") { - constexpr size_t BASE_SIZE = 118860u; + constexpr size_t BASE_SIZE = 188172u; { Setup setup; Index index(setup); diff --git a/searchlib/src/tests/sort/sort_test.cpp b/searchlib/src/tests/sort/sort_test.cpp index ac2c22a0035..1c9ea0efda0 100644 --- a/searchlib/src/tests/sort/sort_test.cpp +++ b/searchlib/src/tests/sort/sort_test.cpp @@ -62,7 +62,7 @@ void Test::testIcu() { { const std::string src("Creation of Bob2007 this is atumated string\this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string;this is atumated string; _ 12345567890-=,./;'[;"); - std::vector<uint16_t> u16Buffer(100); + std::vector<UChar> u16Buffer(100); UErrorCode status = U_ZERO_ERROR; int32_t u16Wanted(0); u_strFromUTF8(&u16Buffer[0], u16Buffer.size(), &u16Wanted, src.c_str(), -1, &status); diff --git a/searchlib/src/tests/sort/uca.cpp b/searchlib/src/tests/sort/uca.cpp index 579f3e7906e..d7a2f66c2d8 100644 --- a/searchlib/src/tests/sort/uca.cpp +++ b/searchlib/src/tests/sort/uca.cpp @@ -34,7 +34,7 @@ void Test::testFromDat() coll->setStrength(Collator::PRIMARY); - std::vector<uint16_t> u16buffer(100); + std::vector<UChar> u16buffer(100); std::vector<uint8_t> u8buffer(10); int fd = open("sort-blobs.dat", O_RDONLY); diff --git a/searchlib/src/vespa/searchlib/attribute/enumstorebase.h b/searchlib/src/vespa/searchlib/attribute/enumstorebase.h index 8cb5fe596b7..f74345a8806 100644 --- a/searchlib/src/vespa/searchlib/attribute/enumstorebase.h +++ b/searchlib/src/vespa/searchlib/attribute/enumstorebase.h @@ -288,6 +288,9 @@ public: uint32_t getRemaining() const { return _store.getBufferState(_store.getActiveBufferId(TYPE_ID)).remaining(); } + uint32_t getCapacity() const { + return _store.getBufferState(_store.getActiveBufferId(TYPE_ID)).capacity(); + } MemoryUsage getMemoryUsage() const; MemoryUsage getTreeMemoryUsage() const { return _enumDict->getTreeMemoryUsage(); } diff --git a/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp index b9042ac5f6c..c63f03ed44e 100644 --- a/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp @@ -8,7 +8,6 @@ namespace search { namespace multivalueattribute { -constexpr size_t HUGE_MEMORY_PAGE_SIZE = 2 * 1024 * 1024; constexpr size_t SMALL_MEMORY_PAGE_SIZE = 4 * 1024; } @@ -19,7 +18,7 @@ MultiValueAttribute(const vespalib::string &baseFileName, const AttributeVector::Config &cfg) : B(baseFileName, cfg), _mvMapping(MultiValueMapping::optimizedConfigForHugePage(1023, - multivalueattribute::HUGE_MEMORY_PAGE_SIZE, + vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE, multivalueattribute::SMALL_MEMORY_PAGE_SIZE, 8 * 1024, cfg.getGrowStrategy().getMultiValueAllocGrowFactor()), diff --git a/searchlib/src/vespa/searchlib/datastore/buffer_type.cpp b/searchlib/src/vespa/searchlib/datastore/buffer_type.cpp index 798c930a3e2..06922835733 100644 --- a/searchlib/src/vespa/searchlib/datastore/buffer_type.cpp +++ b/searchlib/src/vespa/searchlib/datastore/buffer_type.cpp @@ -112,7 +112,7 @@ BufferTypeBase::clampMaxClusters(uint32_t maxClusters) } size_t -BufferTypeBase::calcClustersToAlloc(uint32_t bufferId, size_t sizeNeeded, bool resizing) const +BufferTypeBase::calcClustersToAlloc(uint32_t bufferId, size_t elementsNeeded, bool resizing) const { size_t reservedElements = getReservedElements(bufferId); size_t usedElems = (resizing ? 0 : _activeUsedElems); @@ -121,7 +121,7 @@ BufferTypeBase::calcClustersToAlloc(uint32_t bufferId, size_t sizeNeeded, bool r } assert((usedElems % _clusterSize) == 0); size_t usedClusters = usedElems / _clusterSize; - size_t needClusters = (sizeNeeded + (resizing ? usedElems : reservedElements) + _clusterSize - 1) / _clusterSize; + size_t needClusters = (elementsNeeded + (resizing ? usedElems : reservedElements) + _clusterSize - 1) / _clusterSize; size_t growClusters = (usedClusters * _allocGrowFactor); size_t wantClusters = std::max((resizing ? usedClusters : 0u) + growClusters, static_cast<size_t>(_minClusters)); diff --git a/searchlib/src/vespa/searchlib/datastore/buffer_type.h b/searchlib/src/vespa/searchlib/datastore/buffer_type.h index 321100bb811..adeaa7f4f72 100644 --- a/searchlib/src/vespa/searchlib/datastore/buffer_type.h +++ b/searchlib/src/vespa/searchlib/datastore/buffer_type.h @@ -60,12 +60,12 @@ public: /** * Calculate number of clusters to allocate for new buffer. * - * @param sizeNeeded number of elements needed now + * @param elementsNeeded number of elements needed now * @param clusterRefSize number of clusters expressable via reference type * * @return number of clusters to allocate for new buffer */ - virtual size_t calcClustersToAlloc(uint32_t bufferId, size_t sizeNeeded, bool resizing) const; + virtual size_t calcClustersToAlloc(uint32_t bufferId, size_t elementsNeeded, bool resizing) const; void clampMaxClusters(uint32_t maxClusters); diff --git a/searchlib/src/vespa/searchlib/datastore/bufferstate.cpp b/searchlib/src/vespa/searchlib/datastore/bufferstate.cpp index c2e7c9358e0..9bb6fee7b79 100644 --- a/searchlib/src/vespa/searchlib/datastore/bufferstate.cpp +++ b/searchlib/src/vespa/searchlib/datastore/bufferstate.cpp @@ -4,6 +4,7 @@ #include <limits> using vespalib::alloc::Alloc; +using vespalib::alloc::MemoryAllocator; namespace search::datastore { @@ -30,7 +31,7 @@ BufferState::BufferState() _typeId(0), _clusterSize(0), _compacting(false), - _buffer(Alloc::alloc()) + _buffer(Alloc::alloc(0, MemoryAllocator::HUGEPAGE_SIZE)) { } @@ -45,11 +46,50 @@ BufferState::~BufferState() assert(_freeList.empty()); } +namespace { + +struct AllocResult { + size_t elements; + size_t bytes; + AllocResult(size_t elements_, size_t bytes_) : elements(elements_), bytes(bytes_) {} +}; + +size_t +roundUpToMatchAllocator(size_t sz) +{ + if (sz == 0) { + return 0; + } + // We round up the wanted number of bytes to allocate to match + // the underlying allocator to ensure little to no waste of allocated memory. + if (sz < MemoryAllocator::HUGEPAGE_SIZE) { + // Match heap allocator in vespamalloc. + return vespalib::roundUp2inN(sz); + } else { + // Match mmap allocator. + return MemoryAllocator::roundUpToHugePages(sz); + } +} + +AllocResult +calcAllocation(uint32_t bufferId, + BufferTypeBase &typeHandler, + size_t elementsNeeded, + bool resizing) +{ + size_t allocClusters = typeHandler.calcClustersToAlloc(bufferId, elementsNeeded, resizing); + size_t allocElements = allocClusters * typeHandler.getClusterSize(); + size_t allocBytes = roundUpToMatchAllocator(allocElements * typeHandler.elementSize()); + size_t adjustedAllocElements = (allocBytes / typeHandler.elementSize()); + return AllocResult(adjustedAllocElements, allocBytes); +} + +} void BufferState::onActive(uint32_t bufferId, uint32_t typeId, BufferTypeBase *typeHandler, - size_t sizeNeeded, + size_t elementsNeeded, void *&buffer) { assert(buffer == NULL); @@ -69,13 +109,12 @@ BufferState::onActive(uint32_t bufferId, uint32_t typeId, size_t reservedElements = typeHandler->getReservedElements(bufferId); (void) reservedElements; - size_t allocClusters = typeHandler->calcClustersToAlloc(bufferId, sizeNeeded, false); - size_t allocSize = allocClusters * typeHandler->getClusterSize(); - assert(allocSize >= reservedElements + sizeNeeded); - _buffer.create(allocSize * typeHandler->elementSize()).swap(_buffer); + AllocResult alloc = calcAllocation(bufferId, *typeHandler, elementsNeeded, false); + assert(alloc.elements >= reservedElements + elementsNeeded); + _buffer.create(alloc.bytes).swap(_buffer); buffer = _buffer.get(); - assert(buffer != NULL || allocSize == 0u); - _allocElems = allocSize; + assert(buffer != NULL || alloc.elements == 0u); + _allocElems = alloc.elements; _state = ACTIVE; _typeHandler = typeHandler; _typeId = typeId; @@ -227,26 +266,23 @@ BufferState::disableElemHoldList() void BufferState::fallbackResize(uint32_t bufferId, - uint64_t sizeNeeded, + uint64_t elementsNeeded, void *&buffer, Alloc &holdBuffer) { assert(_state == ACTIVE); assert(_typeHandler != NULL); assert(holdBuffer.get() == NULL); - size_t allocClusters = _typeHandler->calcClustersToAlloc(bufferId, - sizeNeeded, - true); - size_t allocSize = allocClusters * _typeHandler->getClusterSize(); - assert(allocSize >= _usedElems + sizeNeeded); - assert(allocSize > _allocElems); - Alloc newBuffer = _buffer.create(allocSize * _typeHandler->elementSize()); + AllocResult alloc = calcAllocation(bufferId, *_typeHandler, elementsNeeded, true); + assert(alloc.elements >= _usedElems + elementsNeeded); + assert(alloc.elements > _allocElems); + Alloc newBuffer = _buffer.create(alloc.bytes); _typeHandler->fallbackCopy(newBuffer.get(), buffer, _usedElems); holdBuffer.swap(_buffer); std::atomic_thread_fence(std::memory_order_release); _buffer = std::move(newBuffer); buffer = _buffer.get(); - _allocElems = allocSize; + _allocElems = alloc.elements; std::atomic_thread_fence(std::memory_order_release); } diff --git a/searchlib/src/vespa/searchlib/datastore/bufferstate.h b/searchlib/src/vespa/searchlib/datastore/bufferstate.h index 5f579c43751..15c8202a525 100644 --- a/searchlib/src/vespa/searchlib/datastore/bufferstate.h +++ b/searchlib/src/vespa/searchlib/datastore/bufferstate.h @@ -73,14 +73,14 @@ public: /** * Transition from FREE to ACTIVE state. * - * @param bufferId Id of buffer to be active. - * @param typeId registered data type for buffer. - * @param typeHandler type handler for registered data type. - * @param sizeNeeded Number of elements needed to be free - * @param buffer start of buffer. + * @param bufferId Id of buffer to be active. + * @param typeId registered data type for buffer. + * @param typeHandler type handler for registered data type. + * @param elementsNeeded Number of elements needed to be free + * @param buffer start of buffer. */ void onActive(uint32_t bufferId, uint32_t typeId, BufferTypeBase *typeHandler, - size_t sizeNeeded, void *&buffer); + size_t elementsNeeded, void *&buffer); /** * Transition from ACTIVE to HOLD state. @@ -151,7 +151,7 @@ public: size_t getExtraHoldBytes() const { return _extraHoldBytes; } bool getCompacting() const { return _compacting; } void setCompacting() { _compacting = true; } - void fallbackResize(uint32_t bufferId, uint64_t sizeNeeded, void *&buffer, Alloc &holdBuffer); + void fallbackResize(uint32_t bufferId, uint64_t elementsNeeded, void *&buffer, Alloc &holdBuffer); bool isActive(uint32_t typeId) const { return ((_state == ACTIVE) && (_typeId == typeId)); diff --git a/searchlib/src/vespa/searchlib/uca/ucaconverter.cpp b/searchlib/src/vespa/searchlib/uca/ucaconverter.cpp index 47d66a94d9e..fffcc782298 100644 --- a/searchlib/src/vespa/searchlib/uca/ucaconverter.cpp +++ b/searchlib/src/vespa/searchlib/uca/ucaconverter.cpp @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "ucaconverter.h" +#include <unicode/ustring.h> #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/text/utf8.h> #include <mutex> diff --git a/searchsummary/src/vespa/searchsummary/docsummary/attributedfw.cpp b/searchsummary/src/vespa/searchsummary/docsummary/attributedfw.cpp index 363898a8c4e..a15e0e0e0c0 100644 --- a/searchsummary/src/vespa/searchsummary/docsummary/attributedfw.cpp +++ b/searchsummary/src/vespa/searchsummary/docsummary/attributedfw.cpp @@ -6,7 +6,7 @@ #include <vespa/searchlib/attribute/stringbase.h> #include <vespa/searchlib/attribute/integerbase.h> #include <vespa/searchlib/attribute/floatbase.h> -#include <vespa/searchlib/tensor/tensor_attribute.h> +#include <vespa/searchlib/tensor/i_tensor_attribute.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/eval/tensor/serialization/typed_binary_format.h> #include <vespa/vespalib/objects/nbostream.h> @@ -136,8 +136,8 @@ SingleAttrDFW::insertField(uint32_t docid, BasicType::Type t = v.getBasicType(); switch (t) { case BasicType::TENSOR: { - const tensor::TensorAttribute &tv = - static_cast<const tensor::TensorAttribute &>(v); + const tensor::ITensorAttribute &tv = + dynamic_cast<const tensor::ITensorAttribute &>(v); const auto tensor = tv.getTensor(docid); if (tensor) { vespalib::nbostream str; diff --git a/staging_vespalib/src/vespa/vespalib/data/fileheader.cpp b/staging_vespalib/src/vespa/vespalib/data/fileheader.cpp index a9e131b0c63..d166702ab3a 100644 --- a/staging_vespalib/src/vespa/vespalib/data/fileheader.cpp +++ b/staging_vespalib/src/vespa/vespalib/data/fileheader.cpp @@ -16,6 +16,10 @@ const uint32_t GenericHeader::VERSION(1); const GenericHeader::Tag GenericHeader::EMPTY; const size_t ALIGNMENT=0x1000; +GenericHeader::Tag::~Tag() = default; +GenericHeader::Tag::Tag(const Tag &) = default; +GenericHeader::Tag & GenericHeader::Tag::operator=(const Tag &) = default; + GenericHeader::Tag::Tag() : _type(TYPE_EMPTY), _name(""), @@ -156,8 +160,6 @@ GenericHeader::Tag::Tag(const vespalib::string &name, const vespalib::string &va // empty } -GenericHeader::Tag::~Tag() { } - size_t GenericHeader::Tag::getSize() const { diff --git a/staging_vespalib/src/vespa/vespalib/data/fileheader.h b/staging_vespalib/src/vespa/vespalib/data/fileheader.h index e4449a0c36a..ab57f312c4d 100644 --- a/staging_vespalib/src/vespa/vespalib/data/fileheader.h +++ b/staging_vespalib/src/vespa/vespalib/data/fileheader.h @@ -49,6 +49,8 @@ public: public: Tag(); + Tag(const Tag &); + Tag & operator=(const Tag &); Tag(const vespalib::string &name, float val); Tag(const vespalib::string &name, double val); Tag(const vespalib::string &name, int8_t val); diff --git a/standalone-container/src/main/scala/com/yahoo/container/standalone/LocalFileDb.scala b/standalone-container/src/main/scala/com/yahoo/container/standalone/LocalFileDb.scala index 69443c73b3a..6507b4c72f0 100644 --- a/standalone-container/src/main/scala/com/yahoo/container/standalone/LocalFileDb.scala +++ b/standalone-container/src/main/scala/com/yahoo/container/standalone/LocalFileDb.scala @@ -58,6 +58,10 @@ class LocalFileDb(appPath: Path) extends FileAcquirer with FileRegistry { override def export(): util.List[Entry] = { new java.util.ArrayList(fileReferenceToFile.keys.map{ (ref: FileReference) => new Entry(fileReferenceToFile.get(ref).get.getPath, ref)}.asJavaCollection) } + + override def addUri(uri: String): FileReference = { + throw new RuntimeException("addUri(uri: String) is not implemented here."); + } } object LocalFileDb { diff --git a/testutil/pom.xml b/testutil/pom.xml index 00f606860a4..491c144fdb0 100644 --- a/testutil/pom.xml +++ b/testutil/pom.xml @@ -46,6 +46,11 @@ <artifactId>junit</artifactId> <scope>compile</scope> </dependency> + <dependency> + <groupId>com.google.jimfs</groupId> + <artifactId>jimfs</artifactId> + <scope>compile</scope> + </dependency> </dependencies> <build> <plugins> diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/TestFileSystem.java b/testutil/src/main/java/com/yahoo/vespa/test/file/TestFileSystem.java index 465cb671a97..1de62297e0a 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/TestFileSystem.java +++ b/testutil/src/main/java/com/yahoo/vespa/test/file/TestFileSystem.java @@ -1,6 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.node.admin.task.util.file; +package com.yahoo.vespa.test.file; import com.google.common.jimfs.Configuration; import com.google.common.jimfs.Feature; @@ -21,4 +21,6 @@ public class TestFileSystem { .build(); return Jimfs.newFileSystem(configuration); } + + private TestFileSystem() { } } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/AthenzSslContextBuilder.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/AthenzSslContextBuilder.java index 513191d7c83..0c350356986 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/AthenzSslContextBuilder.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/tls/AthenzSslContextBuilder.java @@ -14,7 +14,6 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.security.GeneralSecurityException; import java.security.KeyStore; -import java.security.NoSuchAlgorithmException; import java.security.cert.Certificate; /** @@ -67,9 +66,9 @@ public class AthenzSslContextBuilder { try { SSLContext sslContext = SSLContext.getInstance("TLSv1.2"); TrustManager[] trustManagers = - trustStoreSupplier != null ? createTrustManagers(trustStoreSupplier) : getDefaultTrustManagers(); + trustStoreSupplier != null ? createTrustManagers(trustStoreSupplier) : null; KeyManager[] keyManagers = - keyStoreSupplier != null ? createKeyManagers(keyStoreSupplier, keyStorePassword) : getDefaultKeyManagers(); + keyStoreSupplier != null ? createKeyManagers(keyStoreSupplier, keyStorePassword) : null; sslContext.init(keyManagers, trustManagers, null); return sslContext; } catch (GeneralSecurityException e) { @@ -81,34 +80,18 @@ public class AthenzSslContextBuilder { private static TrustManager[] createTrustManagers(KeyStoreSupplier trustStoreSupplier) throws GeneralSecurityException, IOException { - TrustManagerFactory trustManagerFactory = getTrustManagerFactory(); + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); trustManagerFactory.init(trustStoreSupplier.get()); return trustManagerFactory.getTrustManagers(); } private static KeyManager[] createKeyManagers(KeyStoreSupplier keyStoreSupplier, char[] password) throws GeneralSecurityException, IOException { - KeyManagerFactory keyManagerFactory = getKeyManagerFactory(); + KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); keyManagerFactory.init(keyStoreSupplier.get(), password); return keyManagerFactory.getKeyManagers(); } - private static KeyManager[] getDefaultKeyManagers() throws NoSuchAlgorithmException { - return getKeyManagerFactory().getKeyManagers(); - } - - private static TrustManager[] getDefaultTrustManagers() throws NoSuchAlgorithmException { - return getTrustManagerFactory().getTrustManagers(); - } - - private static KeyManagerFactory getKeyManagerFactory() throws NoSuchAlgorithmException { - return KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); - } - - private static TrustManagerFactory getTrustManagerFactory() throws NoSuchAlgorithmException { - return TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); - } - private static KeyStore loadKeyStoreFromFile(File file, char[] password, String keyStoreType) throws IOException, GeneralSecurityException{ KeyStore keyStore = KeyStore.getInstance(keyStoreType); diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/api/AthenzDomainTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/api/AthenzDomainTest.java index c3fa7396569..2a35fe63d5c 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/api/AthenzDomainTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/api/AthenzDomainTest.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.athenz.api; import org.junit.Test; @@ -53,4 +54,4 @@ public class AthenzDomainTest { } -}
\ No newline at end of file +} diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/AthenzSslContextBuilderTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/AthenzSslContextBuilderTest.java new file mode 100644 index 00000000000..5aca1fc3116 --- /dev/null +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/tls/AthenzSslContextBuilderTest.java @@ -0,0 +1,69 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.athenz.tls; + +import com.yahoo.athenz.auth.util.Crypto; +import org.bouncycastle.asn1.x500.X500Name; +import org.bouncycastle.operator.OperatorCreationException; +import org.bouncycastle.pkcs.PKCS10CertificationRequest; +import org.junit.Test; + +import java.io.IOException; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.KeyStore; +import java.security.NoSuchAlgorithmException; +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; + +/** + * @author bjorncs + */ +public class AthenzSslContextBuilderTest { + + private static final char[] PASSWORD = new char[0]; + + @Test + public void can_build_sslcontext_with_truststore_only() throws Exception { + new AthenzSslContextBuilder() + .withTrustStore(createKeystore()) + .build(); + } + + @Test + public void can_build_sslcontext_with_keystore_only() throws Exception { + new AthenzSslContextBuilder() + .withKeyStore(createKeystore(), PASSWORD) + .build(); + } + + @Test + public void can_build_sslcontext_with_truststore_and_keystore() throws Exception { + new AthenzSslContextBuilder() + .withKeyStore(createKeystore(), PASSWORD) + .withTrustStore(createKeystore()) + .build(); + } + + private static KeyStore createKeystore() throws Exception { + KeyPair keyPair = createKeyPair(); + KeyStore keystore = KeyStore.getInstance("JKS"); + keystore.load(null); + keystore.setKeyEntry("entry-name", keyPair.getPrivate(), PASSWORD, new Certificate[]{createCertificate(keyPair)}); + return keystore; + } + + private static X509Certificate createCertificate(KeyPair keyPair) throws + OperatorCreationException, IOException { + String x500Principal = "CN=mysubject"; + PKCS10CertificationRequest csr = + Crypto.getPKCS10CertRequest( + Crypto.generateX509CSR(keyPair.getPrivate(), x500Principal, null)); + return Crypto.generateX509Certificate(csr, keyPair.getPrivate(), new X500Name(x500Principal), 3600, false); + } + + private static KeyPair createKeyPair() throws NoSuchAlgorithmException { + KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA"); + keyGen.initialize(512); + return keyGen.genKeyPair(); + } +} diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentitiesTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentitiesTest.java index 5dcc853da5a..301012ab635 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentitiesTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentitiesTest.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.athenz.utils; import com.yahoo.vespa.athenz.api.AthenzDomain; @@ -20,4 +21,4 @@ public class AthenzIdentitiesTest { assertEquals(expectedIdentity, actualIdentity); } -}
\ No newline at end of file +} diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentityVerifierTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentityVerifierTest.java index 57f38c3a114..ebbfa232f42 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentityVerifierTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/utils/AthenzIdentityVerifierTest.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.athenz.utils; import com.yahoo.vespa.athenz.api.AthenzIdentity; @@ -82,4 +83,4 @@ public class AthenzIdentityVerifierTest { return sslSession; } -}
\ No newline at end of file +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index 15d4dfc1d00..92c1c0307ec 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -33,6 +33,7 @@ public class ScalarFunctions { public static DoubleUnaryOperator acos() { return new Acos(); } public static DoubleUnaryOperator elu() { return new Elu(); } public static DoubleUnaryOperator exp() { return new Exp(); } + public static DoubleUnaryOperator floor() { return new Floor(); } public static DoubleUnaryOperator relu() { return new Relu(); } public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); } public static DoubleUnaryOperator selu() { return new Selu(); } @@ -126,6 +127,14 @@ public class ScalarFunctions { public String toString() { return "f(a)(exp(a))"; } } + public static class Floor implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.floor(operand); } + @Override + public String toString() { return "f(a)(floor(a))"; } + } + + public static class Relu implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return Math.max(operand, 0); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java index 9643c0a56e7..9e3cd834cb8 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java @@ -1,3 +1,4 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; import com.google.common.collect.ImmutableList; |