diff options
93 files changed, 2295 insertions, 1307 deletions
diff --git a/application-preprocessor/src/main/java/com/yahoo/application/preprocessor/ApplicationPreprocessor.java b/application-preprocessor/src/main/java/com/yahoo/application/preprocessor/ApplicationPreprocessor.java index 3ee926dd124..1988c8a2fff 100644 --- a/application-preprocessor/src/main/java/com/yahoo/application/preprocessor/ApplicationPreprocessor.java +++ b/application-preprocessor/src/main/java/com/yahoo/application/preprocessor/ApplicationPreprocessor.java @@ -5,12 +5,11 @@ import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.application.provider.BaseDeployLogger; import com.yahoo.config.model.application.provider.FilesApplicationPackage; -import com.yahoo.config.provision.*; +import com.yahoo.config.provision.Environment; +import com.yahoo.config.provision.RegionName; +import com.yahoo.config.provision.Zone; import com.yahoo.yolean.Exceptions; -import org.xml.sax.SAXException; -import javax.xml.parsers.ParserConfigurationException; -import javax.xml.transform.TransformerException; import java.io.File; import java.io.IOException; import java.util.Optional; @@ -34,7 +33,7 @@ public class ApplicationPreprocessor { this.region = region; } - public void run() throws IOException, TransformerException, ParserConfigurationException, SAXException { + public void run() throws IOException { DeployLogger logger = new BaseDeployLogger(); FilesApplicationPackage.Builder applicationPackageBuilder = new FilesApplicationPackage.Builder(applicationDir); outputDir.ifPresent(applicationPackageBuilder::preprocessedDir); @@ -44,7 +43,6 @@ public class ApplicationPreprocessor { preprocessed.validateXML(); } - public static void main(String args[]) { int argCount = args.length; if (argCount < 1) { diff --git a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java index 6aa8f8bf1a1..1bdedc503bf 100644 --- a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java +++ b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.config.model.application.provider; import com.yahoo.component.Version; @@ -601,28 +601,32 @@ public class FilesApplicationPackage implements ApplicationPackage { return searchDefinitionContents(); } - private void preprocessXML(File destination, File inputXml, Zone zone) throws ParserConfigurationException, TransformerException, SAXException, IOException { - Document document = new XmlPreProcessor(appDir, - inputXml, - metaData.getApplicationId().instance(), - zone.environment(), - zone.region()).run(); - Transformer transformer = TransformerFactory.newInstance().newTransformer(); - try (FileOutputStream outputStream = new FileOutputStream(destination)) { - transformer.transform(new DOMSource(document), new StreamResult(outputStream)); + private void preprocessXML(File destination, File inputXml, Zone zone) throws IOException { + if ( ! inputXml.exists()) return; + try { + Document document = new XmlPreProcessor(appDir, + inputXml, + metaData.getApplicationId().instance(), + zone.environment(), + zone.region()).run(); + Transformer transformer = TransformerFactory.newInstance().newTransformer(); + try (FileOutputStream outputStream = new FileOutputStream(destination)) { + transformer.transform(new DOMSource(document), new StreamResult(outputStream)); + } + } catch (TransformerException |ParserConfigurationException | SAXException e) { + throw new RuntimeException("Error preprocessing " + inputXml.getAbsolutePath() + ": " + e.getMessage(), e); } } @Override - public ApplicationPackage preprocess(Zone zone, DeployLogger logger) throws IOException, TransformerException, ParserConfigurationException, SAXException { + public ApplicationPackage preprocess(Zone zone, DeployLogger logger) throws IOException { IOUtils.recursiveDeleteDir(preprocessedDir); IOUtils.copyDirectory(appDir, preprocessedDir, -1, (dir, name) -> ! name.equals(preprocessed) && ! name.equals(SERVICES) && ! name.equals(HOSTS) && ! name.equals(CONFIG_DEFINITIONS_DIR)); preprocessXML(new File(preprocessedDir, SERVICES), getServicesFile(), zone); - if (getHostsFile().exists()) - preprocessXML(new File(preprocessedDir, HOSTS), getHostsFile(), zone); + preprocessXML(new File(preprocessedDir, HOSTS), getHostsFile(), zone); FilesApplicationPackage preprocessed = FilesApplicationPackage.fromFile(preprocessedDir, includeSourceFiles); preprocessed.copyUserDefsIntoApplication(); return preprocessed; diff --git a/config-application-package/src/test/java/com/yahoo/config/model/application/provider/FilesApplicationPackageTest.java b/config-application-package/src/test/java/com/yahoo/config/model/application/provider/FilesApplicationPackageTest.java index be4d8fdab25..dcc75fff540 100644 --- a/config-application-package/src/test/java/com/yahoo/config/model/application/provider/FilesApplicationPackageTest.java +++ b/config-application-package/src/test/java/com/yahoo/config/model/application/provider/FilesApplicationPackageTest.java @@ -10,10 +10,7 @@ import com.yahoo.io.IOUtils; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; -import org.xml.sax.SAXException; -import javax.xml.parsers.ParserConfigurationException; -import javax.xml.transform.TransformerException; import java.io.File; import java.io.FileReader; import java.io.IOException; @@ -32,7 +29,7 @@ public class FilesApplicationPackageTest { public TemporaryFolder temporaryFolder = new TemporaryFolder(); @Test - public void testPreprocessing() throws IOException, TransformerException, ParserConfigurationException, SAXException { + public void testPreprocessing() throws IOException { File appDir = temporaryFolder.newFolder(); IOUtils.copyDirectory(new File("src/test/resources/multienvapp"), appDir); assertTrue(new File(appDir, "services.xml").exists()); diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java index b6f030fab52..45867eedb31 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java @@ -250,6 +250,7 @@ public interface ApplicationPackage { * * @return A new application package instance pointing to a new location */ + // TODO: TransformerException, ParserConfigurationException, SAXException can be removed from 'throws' when 7.308 is latest version in use default ApplicationPackage preprocess(Zone zone, DeployLogger logger) throws IOException, TransformerException, ParserConfigurationException, SAXException { throw new UnsupportedOperationException("This application package does not support preprocessing"); diff --git a/config-model/pom.xml b/config-model/pom.xml index 95e79fd09fb..c0751431d03 100644 --- a/config-model/pom.xml +++ b/config-model/pom.xml @@ -46,6 +46,11 @@ <scope>test</scope> </dependency> <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + <version>${protobuf.version}</version> + </dependency> + <dependency> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> <scope>provided</scope> @@ -498,6 +503,10 @@ <updateReleaseInfo>true</updateReleaseInfo> </configuration> </plugin> + <plugin> + <groupId>com.github.os72</groupId> + <artifactId>protoc-jar-maven-plugin</artifactId> + </plugin> </plugins> </build> </project> diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 4011ce43841..ab42e4d821a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; +import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; @@ -94,7 +95,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement // are there other cases we would like to resolve globally? } - @Override + @Override public TensorType getType(Reference reference) { // computeIfAbsent without concurrent modification due to resolve adding more resolved entries: @@ -158,6 +159,12 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments()))); } + // A reference to an ONNX model? + Optional<TensorType> onnxFeatureType = onnxFeatureType(reference); + if (onnxFeatureType.isPresent()) { + return onnxFeatureType.get(); + } + // A reference to a feature which returns a tensor? Optional<TensorType> featureTensorType = tensorFeatureType(reference); if (featureTensorType.isPresent()) { @@ -210,6 +217,26 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return Optional.of(function); } + private Optional<TensorType> onnxFeatureType(Reference reference) { + if ( ! reference.name().equals("onnxModel")) + return Optional.empty(); + + if ( ! featureTypes.containsKey(reference)) { + String configOrFileName = reference.arguments().expressions().get(0).toString(); + + // Look up standardized format as added in RankProfile + String modelConfigName = OnnxModelTransformer.getModelConfigName(reference); + String modelOutput = OnnxModelTransformer.getModelOutput(reference, null); + + reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput); + if ( ! featureTypes.containsKey(reference)) { + throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'"); + } + } + + return Optional.of(featureTypes.get(reference)); + } + /** * There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet. * This returns the type of those features if this is a reference to either of them, or empty otherwise. diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index c2fb2107604..64338e24a8d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -2,13 +2,16 @@ package com.yahoo.searchdefinition; import com.yahoo.config.FileReference; +import com.yahoo.path.Path; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.AbstractService; +import com.yahoo.vespa.model.ml.OnnxModelInfo; import com.yahoo.vespa.model.utils.FileSender; -import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.List; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; /** @@ -21,16 +24,12 @@ public class OnnxModel { public enum PathType {FILE, URI}; private final String name; + private PathType pathType = PathType.FILE; private String path = null; private String fileReference = ""; - private List<OnnxNameMapping> inputMap = new ArrayList<>(); - private List<OnnxNameMapping> outputMap = new ArrayList<>(); - - public PathType getPathType() { - return pathType; - } - - private PathType pathType = PathType.FILE; + private OnnxModelInfo modelInfo = null; + private Map<String, String> inputMap = new HashMap<>(); + private Map<String, String> outputMap = new HashMap<>(); public OnnxModel(String name) { this.name = name; @@ -49,21 +48,40 @@ public class OnnxModel { } public void setUri(String uri) { - Objects.requireNonNull(uri, "uri cannot be null"); - this.path = uri; - this.pathType = PathType.URI; + throw new IllegalArgumentException("URI for ONNX models are not currently supported"); + } + + public PathType getPathType() { + return pathType; } public void addInputNameMapping(String onnxName, String vespaName) { + addInputNameMapping(onnxName, vespaName, true); + } + + public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - this.inputMap.add(new OnnxNameMapping(onnxName, vespaName)); + if (overwrite || ! inputMap.containsKey(onnxName)) { + inputMap.put(onnxName, vespaName); + } } public void addOutputNameMapping(String onnxName, String vespaName) { + addOutputNameMapping(onnxName, vespaName, true); + } + + public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - this.outputMap.add(new OnnxNameMapping(onnxName, vespaName)); + if (overwrite || ! outputMap.containsKey(onnxName)) { + outputMap.put(onnxName, vespaName); + } + } + + public void setModelInfo(OnnxModelInfo modelInfo) { + Objects.requireNonNull(modelInfo, "Onnx model info cannot be null"); + this.modelInfo = modelInfo; } /** Initiate sending of this constant to some services over file distribution */ @@ -76,11 +94,20 @@ public class OnnxModel { public String getName() { return name; } public String getFileName() { return path; } + public Path getFilePath() { return Path.fromString(path); } public String getUri() { return path; } public String getFileReference() { return fileReference; } - public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); } - public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); } + public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); } + public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); } + + public String getDefaultOutput() { + return modelInfo != null ? modelInfo.getDefaultOutput() : ""; + } + + TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) { + return modelInfo != null ? modelInfo.getTensorType(onnxName, inputTypes) : TensorType.empty; + } public void validate() { if (path == null || path.isEmpty()) @@ -90,23 +117,10 @@ public class OnnxModel { public String toString() { StringBuilder b = new StringBuilder(); b.append("onnx-model '").append(name) - .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) - .append("' with ref '").append(fileReference) - .append("'"); + .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) + .append("' with ref '").append(fileReference) + .append("'"); return b.toString(); } - public static class OnnxNameMapping { - private String onnxName; - private String vespaName; - - private OnnxNameMapping(String onnxName, String vespaName) { - this.onnxName = onnxName; - this.vespaName = vespaName; - } - public String getOnnxName() { return onnxName; } - public String getVespaName() { return vespaName; } - public void setVespaName(String vespaName) { this.vespaName = vespaName; } - } - } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index d309f48d6df..9b129eb66ce 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -18,6 +18,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.VespaModel; @@ -158,6 +159,10 @@ public class RankProfile implements Cloneable { return search != null ? search.rankingConstants() : model.rankingConstants(); } + private Map<String, OnnxModel> onnxModels() { + return search != null ? search.onnxModels().asMap() : Collections.emptyMap(); + } + private Stream<ImmutableSDField> allFields() { if (search == null) return Stream.empty(); if (allFieldsList == null) { @@ -821,9 +826,49 @@ public class RankProfile implements Cloneable { } } + // Add output types for ONNX models + for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) { + String modelName = entry.getKey(); + OnnxModel model = entry.getValue(); + Arguments args = new Arguments(new ReferenceNode(modelName)); + Map<String, TensorType> inputTypes = resolveOnnxInputTypes(model, context); + + TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), inputTypes); + context.setType(new Reference("onnxModel", args, null), defaultOutputType); + + for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) { + TensorType type = model.getTensorType(mapping.getKey(), inputTypes); + context.setType(new Reference("onnxModel", args, mapping.getValue()), type); + } + } return context; } + private Map<String, TensorType> resolveOnnxInputTypes(OnnxModel model, MapEvaluationTypeContext context) { + Map<String, TensorType> inputTypes = new HashMap<>(); + for (String onnxInputName : model.getInputMap().keySet()) { + resolveOnnxInputType(onnxInputName, model, context).ifPresent(type -> inputTypes.put(onnxInputName, type)); + } + return inputTypes; + } + + private Optional<TensorType> resolveOnnxInputType(String onnxInputName, OnnxModel model, MapEvaluationTypeContext context) { + String source = model.getInputMap().get(onnxInputName); + if (source != null) { + // Source is either a simple reference (query/attribute/constant)... + Optional<Reference> reference = Reference.simple(source); + if (reference.isPresent()) { + return Optional.of(context.getType(reference.get())); + } + // ... or a function + ExpressionFunction func = context.getFunction(source); + if (func != null) { + return Optional.of(func.getBody().type(context)); + } + } + return Optional.empty(); // if this context does not contain this input + } + private void addAttributeFeatureTypes(ImmutableSDField field, Map<Reference, TensorType> featureTypes) { Attribute attribute = field.getAttribute(); field.getAttributes().forEach((k, a) -> { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 84442fedc48..22a32c8fd65 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -126,8 +126,8 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); modelBuilder.name(model.getName()); modelBuilder.fileref(model.getFileReference()); - model.getInputMap().forEach(mapper -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(mapper.getOnnxName()).source(mapper.getVespaName()))); - model.getOutputMap().forEach(mapper -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(mapper.getOnnxName()).as(mapper.getVespaName()))); + model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); + model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); builder.model(modelBuilder); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 87eaaf0387a..56a5d539906 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -448,10 +448,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer { Set<String> functionNames = rankProfile.getFunctions().keySet(); if (functionNames.isEmpty()) return; for (OnnxModel onnxModel: rankProfile.getSearch().onnxModels().asMap().values()) { - for (OnnxModel.OnnxNameMapping mapping : onnxModel.getInputMap()) { - String source = mapping.getVespaName(); + for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) { + String source = mapping.getValue(); if (functionNames.contains(source)) { - mapping.setVespaName("rankingExpression(" + source + ")"); + onnxModel.addInputNameMapping(mapping.getKey(), "rankingExpression(" + source + ")"); } } } @@ -462,7 +462,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { Set<ReferenceNode> replacedSummaryFeatures = new HashSet<>(); for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) { ReferenceNode referenceNode = i.next(); - ReferenceNode replacedNode = OnnxModelTransformer.transformFeature(referenceNode, rankProfile.getSearch()); + ReferenceNode replacedNode = (ReferenceNode) OnnxModelTransformer.transformFeature(referenceNode, rankProfile); if (referenceNode != replacedNode) { replacedSummaryFeatures.add(replacedNode); i.remove(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index ec517768ea9..d23a8376e7a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -38,7 +38,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans } private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { - if ( ! feature.getName().equals("onnx")) return feature; + if ( ! feature.getName().equals("onnx") && ! feature.getName().equals("onnx_vespa")) return feature; try { FeatureArguments arguments = asFeatureArguments(feature.getArguments()); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java index e1ad003e5bd..69cdae10e47 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java @@ -1,20 +1,36 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; +import com.yahoo.path.Path; import com.yahoo.searchdefinition.ImmutableSearch; import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; +import com.yahoo.vespa.model.ml.ModelName; import java.util.List; /** - * Transforms instances of the onnxModel ranking feature and generates - * ONNX configuration if necessary. + * Transforms ONNX model features of the forms: + * + * onnxModel(config_name) + * onnxModel(config_name).output + * onnxModel("path/to/model") + * onnxModel("path/to/model").output + * onnxModel("path/to/model", "path/to/output") + * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused + * + * To the format expected by the backend: + * + * onnxModel(config_name).output * * @author lesters */ @@ -33,85 +49,92 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { if (context.rankProfile() == null) return feature; if (context.rankProfile().getSearch() == null) return feature; - return transformFeature(feature, context.rankProfile().getSearch()); + return transformFeature(feature, context.rankProfile()); } - public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) { - if (!feature.getName().equals("onnxModel")) return feature; + public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) { + ImmutableSearch search = rankProfile.getSearch(); + final String featureName = feature.getName(); + if ( ! featureName.equals("onnxModel")) return feature; Arguments arguments = feature.getArguments(); if (arguments.isEmpty()) - throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " + - "onnx-model config or a ONNX file."); - if (arguments.expressions().size() > 2) - throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments."); - - // Validation that the file actually exists is handled when the file is added to file distribution. - // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator. - - String modelConfigName; - OnnxModel onnxModel; - if (arguments.expressions().get(0) instanceof ReferenceNode) { - modelConfigName = arguments.expressions().get(0).toString(); - onnxModel = search.onnxModels().get(modelConfigName); - if (onnxModel == null) { - throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found"); - } - } else if (arguments.expressions().get(0) instanceof ConstantNode) { + throw new IllegalArgumentException("An " + featureName + " feature must take an argument referring to a " + + "onnx-model config or an ONNX file."); + if (arguments.expressions().size() > 3) + throw new IllegalArgumentException("An " + featureName + " feature can have at most 3 arguments."); + + // Check that the model configuration "onnx-model" exists. If not defined, it should have been added + // by the "OnnxModelConfigGenerator" processor. If it still doesn't exist, it is because we can't find + // the actual ONNX file, which can happen if we are restarting or upgrading an application using an + // ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store. + + String modelConfigName = getModelConfigName(feature.reference()); + OnnxModel onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { String path = asString(arguments.expressions().get(0)); - modelConfigName = asValidIdentifier(path); - onnxModel = search.onnxModels().get(modelConfigName); - if (onnxModel == null) { - onnxModel = new OnnxModel(modelConfigName, path); - search.onnxModels().add(onnxModel); - } - } else { - throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'"); + ModelName modelName = new ModelName(null, Path.fromString(path), true); + ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, path, rankProfile); + FeatureArguments featureArguments = new FeatureArguments(arguments); + return convertedModel.expression(featureArguments, null); } - String output = null; - if (feature.getOutput() != null) { - output = feature.getOutput(); - if ( ! hasOutputMapping(onnxModel, output)) { - onnxModel.addOutputNameMapping(output, output); + String defaultOutput = onnxModel.getOutputMap().get(onnxModel.getDefaultOutput()); + String output = getModelOutput(feature.reference(), defaultOutput); + if (! onnxModel.getOutputMap().containsValue(output)) { + throw new IllegalArgumentException(featureName + " argument '" + output + + "' output not found in model '" + onnxModel.getFileName() + "'"); + } + return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output); + } + + public static String getModelConfigName(Reference reference) { + if (reference.arguments().size() > 0) { + ExpressionNode expr = reference.arguments().expressions().get(0); + if (expr instanceof ReferenceNode) { // refers to onnx-model config + return expr.toString(); } - } else if (arguments.expressions().size() > 1) { - String name = asString(arguments.expressions().get(1)); - output = asValidIdentifier(name); - if ( ! hasOutputMapping(onnxModel, output)) { - onnxModel.addOutputNameMapping(name, output); + if (expr instanceof ConstantNode) { // refers to an file path + return asValidIdentifier(expr); } } + return null; + } - // Replace feature with name of config - ExpressionNode argument = new ReferenceNode(modelConfigName); - return new ReferenceNode("onnxModel", List.of(argument), output); - + public static String getModelOutput(Reference reference, String defaultOutput) { + if (reference.output() != null) { + return reference.output(); + } else if (reference.arguments().expressions().size() == 2) { + return asValidIdentifier(reference.arguments().expressions().get(1)); + } else if (reference.arguments().expressions().size() > 2) { + return asValidIdentifier(reference.arguments().expressions().get(2)); + } + return defaultOutput; } - private static boolean hasOutputMapping(OnnxModel onnxModel, String as) { - return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as)); + public static String stripQuotes(String s) { + if (isNotQuoteSign(s.codePointAt(0))) return s; + if (isNotQuoteSign(s.codePointAt(s.length() - 1))) + throw new IllegalArgumentException("argument [" + s + "] is missing end quote"); + return s.substring(1, s.length()-1); } - private static String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); + public static String asValidIdentifier(String str) { + return str.replaceAll("[^\\w\\d\\$@_]", "_"); } - private static String stripQuotes(String s) { - if ( ! isQuoteSign(s.codePointAt(0))) return s; - if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); - return s.substring(1, s.length()-1); + private static String asValidIdentifier(ExpressionNode node) { + return asValidIdentifier(asString(node)); } - private static boolean isQuoteSign(int c) { - return c == '\'' || c == '"'; + private static boolean isNotQuoteSign(int c) { + return c != '\'' && c != '"'; } - private static String asValidIdentifier(String str) { - return str.replaceAll("[^\\w\\d\\$@_]", "_"); + public static String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java new file mode 100644 index 00000000000..70ad3b255e3 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java @@ -0,0 +1,98 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchdefinition.processing; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.application.api.DeployLogger; +import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.vespa.model.container.search.QueryProfiles; +import com.yahoo.vespa.model.ml.OnnxModelInfo; + +import java.util.Map; + +/** + * Processes ONNX ranking features of the form: + * + * onnx("files/model.onnx", "path/to/output:1") + * + * And generates an "onnx-model" configuration as if it was defined in the schema: + * + * onnx-model files_model_onnx { + * file: "files/model.onnx" + * } + * + * Inputs and outputs are resolved in OnnxModelTypeResolver, which must be + * processed after this. + * + * @author lesters + */ +public class OnnxModelConfigGenerator extends Processor { + + public OnnxModelConfigGenerator(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) { + super(search, deployLogger, rankProfileRegistry, queryProfiles); + } + + @Override + public void process(boolean validate, boolean documentsOnly) { + if (documentsOnly) return; + for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) { + if (profile.getFirstPhaseRanking() != null) { + process(profile.getFirstPhaseRanking().getRoot()); + } + if (profile.getSecondPhaseRanking() != null) { + process(profile.getSecondPhaseRanking().getRoot()); + } + for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) { + process(function.getValue().function().getBody().getRoot()); + } + for (ReferenceNode feature : profile.getSummaryFeatures()) { + process(feature); + } + } + } + + private void process(ExpressionNode node) { + if (node instanceof ReferenceNode) { + process((ReferenceNode)node); + } else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode) node).children()) { + process(child); + } + } + } + + private void process(ReferenceNode feature) { + if (feature.getName().equals("onnxModel") || feature.getName().equals("onnx")) { + if (feature.getArguments().size() > 0) { + if (feature.getArguments().expressions().get(0) instanceof ConstantNode) { + ConstantNode node = (ConstantNode) feature.getArguments().expressions().get(0); + String path = OnnxModelTransformer.stripQuotes(node.sourceString()); + String modelConfigName = OnnxModelTransformer.asValidIdentifier(path); + + // Only add the configuration if the model can actually be found. + if ( ! OnnxModelInfo.modelExists(path, search.applicationPackage())) { + path = ApplicationPackage.MODELS_DIR.append(path).toString(); + if ( ! OnnxModelInfo.modelExists(path, search.applicationPackage())) { + return; + } + } + + OnnxModel onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + onnxModel = new OnnxModel(modelConfigName, path); + search.onnxModels().add(onnxModel); + } + } + } + } + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java new file mode 100644 index 00000000000..8e92b1980ac --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java @@ -0,0 +1,47 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchdefinition.processing; + +import com.yahoo.config.application.api.DeployLogger; +import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.vespa.model.container.search.QueryProfiles; +import com.yahoo.vespa.model.ml.OnnxModelInfo; + +/** + * Processes every "onnx-model" element in the schema. Associates model type + * information by retrieving from either the ONNX model file directly or from + * preprocessed information in ZK. Adds missing input and output mappings + * (assigning default names). + * + * Must be processed before RankingExpressingTypeResolver. + * + * @author lesters + */ +public class OnnxModelTypeResolver extends Processor { + + public OnnxModelTypeResolver(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) { + super(search, deployLogger, rankProfileRegistry, queryProfiles); + } + + @Override + public void process(boolean validate, boolean documentsOnly) { + if (documentsOnly) return; + + for (OnnxModel onnxModel : search.onnxModels().asMap().values()) { + OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(onnxModel.getFileName(), search.applicationPackage()); + + // Add any missing input and output fields that were not specified in the onnx-model configuration + for (String onnxName : onnxModelInfo.getInputs()) { + onnxModel.addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); + } + for (String onnxName : onnxModelInfo.getOutputs()) { + onnxModel.addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); + } + + onnxModel.setModelInfo(onnxModelInfo); + } + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java index e8594c2a87f..1a3ef9e54b4 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java @@ -74,6 +74,8 @@ public class Processing { ReferenceFieldsProcessor::new, FastAccessValidator::new, ReservedFunctionNames::new, + OnnxModelConfigGenerator::new, + OnnxModelTypeResolver::new, RankingExpressionTypeResolver::new, // These should be last: IndexingValidation::new, diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java index c6c7969e466..00797876395 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java @@ -5,16 +5,16 @@ import com.yahoo.cloud.config.ConfigserverConfig; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.IOUtils; import com.yahoo.log.InvalidLogFormatException; -import java.util.logging.Level; import com.yahoo.log.LogMessage; +import com.yahoo.path.Path; import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.vespa.defaults.Defaults; import com.yahoo.yolean.Exceptions; import com.yahoo.system.ProcessExecuter; import com.yahoo.text.StringUtilities; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.collections.Pair; import com.yahoo.config.ConfigInstance; -import com.yahoo.vespa.defaults.Defaults; import com.yahoo.vespa.config.search.ImportedFieldsConfig; import com.yahoo.vespa.config.search.IndexschemaConfig; import com.yahoo.vespa.config.search.RankProfilesConfig; @@ -35,6 +35,7 @@ import java.nio.file.Paths; import java.time.Duration; import java.time.Instant; import java.util.logging.Logger; +import java.util.logging.Level; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -152,12 +153,9 @@ public class RankSetupValidator extends Validator { // Assist verify-ranksetup in finding the actual ONNX model files Map<String, OnnxModel> models = db.getDerivedConfiguration().getSearch().onnxModels().asMap(); if (models.values().size() > 0) { - ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults - String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir()); List<String> config = new ArrayList<>(models.values().size() * 2); for (OnnxModel model : models.values()) { - String modelFilename = Paths.get(model.getFileName()).getFileName().toString(); - String modelPath = Paths.get(fileRefDir, model.getFileReference(), modelFilename).toString(); + String modelPath = getFileRepositoryPath(model.getFilePath(), model.getFileReference()); config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference())); config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath)); } @@ -165,6 +163,12 @@ public class RankSetupValidator extends Validator { } } + public static String getFileRepositoryPath(Path path, String fileReference) { + ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults + String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir()); + return Paths.get(fileRefDir, fileReference, path.getName()).toString(); + } + private static void writeConfig(String dir, String configName, ConfigInstance config) throws IOException { IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(ConfigInstance.serialize(config)), false); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 943fcbf6c1d..5ee6ed02e61 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -150,7 +150,7 @@ public class ConvertedModel { */ public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) { ExpressionFunction expression = selectExpression(arguments); - if (sourceModel.isPresent()) // we should verify + if (sourceModel.isPresent() && context != null) // we should verify verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles()); return expression.getBody().getRoot(); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java new file mode 100644 index 00000000000..7526a8a8595 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -0,0 +1,389 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; + +import com.fasterxml.jackson.core.JsonEncoding; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.tensor.TensorType; +import onnx.Onnx; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Model information (input and output types) for an ONNX model. + * This encapsulates the difference between reading ONNX model information + * - from a file application package, where we can read the ONNX model directly + * - from a ZK application package, where the file is unavailable and models are read from + * generated files stored in file distribution or ZooKeeper. + * + * @author lesters + */ +public class OnnxModelInfo { + + private final String defaultOutput; + private final Map<String, OnnxTypeInfo> inputs; + private final Map<String, OnnxTypeInfo> outputs; + private final Map<String, TensorType> vespaTypes = new HashMap<>(); + + private OnnxModelInfo(Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { + this.inputs = Collections.unmodifiableMap(inputs); + this.outputs = Collections.unmodifiableMap(outputs); + this.defaultOutput = defaultOutput; + } + + public Set<String> getInputs() { + return inputs.keySet(); + } + + public Set<String> getOutputs() { + return outputs.keySet(); + } + + public String getDefaultOutput() { + return defaultOutput; + } + + /** + * Return the tensor type for an ONNX model for the given context. + * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output + * type depends on the input types for the given context (rank profile). + */ + public TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) { + OnnxTypeInfo onnxTypeInfo = outputs.get(onnxName); + if (onnxTypeInfo == null) { + throw new IllegalArgumentException("Could not find type for output '" + onnxName + "'"); + } + if (onnxTypeInfo.containsUnknownDimensionSizes()) { + Set<Long> unboundSizes = new HashSet<>(); + Map<String, Long> symbolicSizes = new HashMap<>(); + resolveUnknownDimensionSizes(inputTypes, symbolicSizes, unboundSizes); + return onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes); + } + return vespaTypes.computeIfAbsent(onnxName, v -> onnxTypeInfo.toVespaTensorType()); + } + + private void resolveUnknownDimensionSizes(Map<String, TensorType> inputTypes, + Map<String, Long> symbolicSizes, + Set<Long> unboundSizes) + { + for (Map.Entry<String, OnnxTypeInfo> input : inputs.entrySet()) { + String onnxName = input.getKey(); + OnnxTypeInfo onnxType = input.getValue(); + TensorType vespaType = inputTypes.get(onnxName); + if (vespaType == null || vespaType.dimensions().size() != onnxType.dimensions().size()) { + continue; + } + + for (int i = 0; i < vespaType.dimensions().size(); ++i) { + if (vespaType.dimensions().get(i).size().isEmpty()) { + continue; + } + Long size = vespaType.dimensions().get(i).size().get(); + + // Handle dimensions with size -1 - typically batch dimensions + if (onnxType.dimensions().get(i).getSize() == -1) { + unboundSizes.add(size); + if (unboundSizes.size() > 1) { + throw new IllegalArgumentException("Found conflicting sizes for unbound dimension " + + "for type '" + onnxType + "'"); + } + + // Handle dimensions with symbolic names + } else if (onnxType.dimensions().get(i).hasSymbolicName()) { + String symbolicName = onnxType.dimensions().get(i).getSymbolicName(); + if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) { + throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" + + symbolicName + "' for input '" + onnxName + "'"); + } + symbolicSizes.put(symbolicName, size); + } + } + } + } + + static public OnnxModelInfo load(String path, ApplicationPackage app) { + Path pathInApplicationPackage = Path.fromString(path); + if (app.getFile(pathInApplicationPackage).exists()) { + return loadFromFile(pathInApplicationPackage, app); + } + if (app.getFile(generatedModelInfoPath(pathInApplicationPackage)).exists()) { + return loadFromGeneratedInfo(pathInApplicationPackage, app); + } + throw new IllegalArgumentException("Unable to find ONNX model file or generated ONNX info file"); + } + + static public boolean modelExists(String path, ApplicationPackage app) { + Path pathInApplicationPackage = Path.fromString(path); + if (app.getFile(pathInApplicationPackage).exists()) { + return true; + } + if (app.getFile(generatedModelInfoPath(Path.fromString(path))).exists()) { + return true; + } + return false; + } + + static private OnnxModelInfo loadFromFile(Path path, ApplicationPackage app) { + try (InputStream inputStream = app.getFile(path).createInputStream()) { + Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); + String json = onnxModelToJson(model); + storeGeneratedInfo(json, path, app); + return jsonToModelInfo(json); + } catch (IOException e) { + throw new IllegalArgumentException("Unable to parse ONNX model", e); + } + } + + static private OnnxModelInfo loadFromGeneratedInfo(Path path, ApplicationPackage app) { + try { + String json = readGeneratedInfo(path, app); + return jsonToModelInfo(json); + } catch (IOException e) { + throw new IllegalArgumentException("Unable to parse ONNX model", e); + } + } + + static private String readGeneratedInfo(Path path, ApplicationPackage app) throws IOException { + ApplicationFile file = app.getFile(generatedModelInfoPath(path)); + return IOUtils.readAll(file.createReader()); + } + + static private void storeGeneratedInfo(String json, Path path, ApplicationPackage app) throws IOException { + IOUtils.writeFile(app.getFileReference(generatedModelInfoPath(path)), json, false); + } + + static private Path generatedModelInfoPath(Path path) { + String fileName = asValidIdentifier(path.getRelative()) + ".modelinfo.json"; + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName); + } + + static private String onnxModelToJson(Onnx.ModelProto model) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8); + g.writeStartObject(); + + g.writeArrayFieldStart("inputs"); + for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) { + onnxTypeToJson(g, valueInfo); + } + g.writeEndArray(); + + g.writeArrayFieldStart("outputs"); + for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) { + onnxTypeToJson(g, valueInfo); + } + g.writeEndArray(); + + g.writeEndObject(); + g.close(); + return out.toString(); + } + + static public OnnxModelInfo jsonToModelInfo(String json) throws IOException { + ObjectMapper m = new ObjectMapper(); + JsonNode root = m.readTree(json); + Map<String, OnnxTypeInfo> inputs = new HashMap<>(); + Map<String, OnnxTypeInfo> outputs = new HashMap<>(); + String defaultOutput = ""; + + for (JsonNode input : root.get("inputs")) { + inputs.put(input.get("name").textValue(), jsonToTypeInfo(input)); + } + for (JsonNode output : root.get("outputs")) { + outputs.put(output.get("name").textValue(), jsonToTypeInfo(output)); + } + if (root.get("outputs").has(0)) { + defaultOutput = root.get("outputs").get(0).get("name").textValue(); + } + return new OnnxModelInfo(inputs, outputs, defaultOutput); + } + + static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException { + g.writeStartObject(); + g.writeStringField("name", valueInfo.getName()); + g.writeStringField("type", onnxValueTypeToString(valueInfo.getType().getTensorType().getElemType())); + g.writeArrayFieldStart("dim"); + for (Onnx.TensorShapeProto.Dimension dim : valueInfo.getType().getTensorType().getShape().getDimList()) { + g.writeStartObject(); + if (dim.hasDimParam()) { + g.writeStringField("type", "param"); + g.writeStringField("size", dim.getDimParam()); + } else { + g.writeStringField("type", "value"); + g.writeNumberField("size", dim.getDimValue()); + } + g.writeEndObject(); + } + g.writeEndArray(); + g.writeEndObject(); + } + + static private OnnxTypeInfo jsonToTypeInfo(JsonNode node) { + TensorType.Value valueType = stringToValueType(node.get("type").textValue()); + OnnxTypeInfo type = new OnnxTypeInfo(valueType); + for (JsonNode dim : node.get("dim")) { + if (dim.get("type").textValue().equals("param")) { + type.addDimension(dim.get("size").textValue()); + } else { + type.addDimension(dim.get("size").longValue()); + } + } + return type; + } + + private static String onnxValueTypeToString(Onnx.TensorProto.DataType dataType) { + switch (dataType) { + case FLOAT: return "float"; + case DOUBLE: return "double"; + // Imperfect conversion, for now: + case BOOL: return "float"; + case INT8: return "float"; + case INT16: return "float"; + case INT32: return "float"; + case INT64: return "float"; + case UINT8: return "float"; + case UINT16: return "float"; + case UINT32: return "float"; + case UINT64: return "float"; + default: + throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); + } + } + + private static TensorType.Value stringToValueType(String type) { + switch (type) { + case "float": return TensorType.Value.FLOAT; + case "double": return TensorType.Value.DOUBLE; + default: + throw new IllegalArgumentException("Unknown tensor value type: " + type); + } + } + + public static String asValidIdentifier(String str) { + return str.replaceAll("[^\\w\\d\\$@_]", "_"); + } + + + private static class OnnxTypeInfo { + private final TensorType.Value valueType; + private final List<OnnxDimensionInfo> dimensions = new ArrayList<>(); + + OnnxTypeInfo(TensorType.Value valueType) { + this.valueType = valueType; + } + + void addDimension(long value) { + dimensions.add(new OnnxDimensionInfo(value)); + } + + void addDimension(String param) { + dimensions.add(new OnnxDimensionInfo(param)); + } + + boolean containsUnknownDimensionSizes() { + return dimensions.stream().anyMatch(OnnxDimensionInfo::unknownDimensionSize); + } + + TensorType.Value valueType() { + return valueType; + } + + List<OnnxDimensionInfo> dimensions() { + return dimensions; + } + + TensorType toVespaTensorType() { + return toVespaTensorType(null, null); + } + + TensorType toVespaTensorType(Map<String, Long> symbolicSizes, Set<Long> unboundSizes) { + String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... + TensorType.Builder builder = new TensorType.Builder(valueType); + for (int i = 0; i < dimensions.size(); ++ i) { + String dimensionName = dimensionPrefix + i; + OnnxDimensionInfo onnxDimension = dimensions.get(i); + long onnxDimensionSize = onnxDimension.getSize(); + if (onnxDimension.hasSymbolicName() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getSymbolicName())) { + onnxDimensionSize = symbolicSizes.get(onnxDimension.getSymbolicName()); + } + if (onnxDimensionSize == 0 && symbolicSizes != null) { + // This is for the case where all symbolic dimensions have + // different names, but can be resolved to a single dimension size. + Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values()); + if (unknownSizes.size() == 1) { + onnxDimensionSize = unknownSizes.iterator().next(); + } + } + if (onnxDimensionSize < 0 && unboundSizes != null && unboundSizes.size() > 0) { + onnxDimensionSize = unboundSizes.iterator().next(); + } + if (onnxDimensionSize <= 0) { + return TensorType.empty; // Unable to determine type - probably out of context + } + builder.indexed(dimensionName, onnxDimensionSize); + } + return builder.build(); + } + + @Override + public String toString() { + return "(" + valueType.id() + ")" + + "[" + dimensions.stream().map(OnnxDimensionInfo::toString).collect(Collectors.joining(",")) + "]"; + } + + } + + private static class OnnxDimensionInfo { + private final long size; + private final String symbolicName; + + OnnxDimensionInfo(long size) { + this.size = size; + this.symbolicName = null; + } + + OnnxDimensionInfo(String symbolicName) { + this.size = 0; + this.symbolicName = symbolicName; + } + + long getSize() { + return size; + } + + String getSymbolicName() { + return symbolicName; + } + + boolean hasSymbolicName() { + return symbolicName != null; + } + + boolean unknownDimensionSize() { + return hasSymbolicName() || size <= 0; + } + + @Override + public String toString() { + return hasSymbolicName() ? "\"" + symbolicName + "\"" : Long.toString(size); + } + } + +} diff --git a/config-model/src/main/protobuf/onnx.proto b/config-model/src/main/protobuf/onnx.proto new file mode 100644 index 00000000000..dc6542867e0 --- /dev/null +++ b/config-model/src/main/protobuf/onnx.proto @@ -0,0 +1,464 @@ +// +// WARNING: This file is automatically generated! Please edit onnx.in.proto. +// + + +// Copyright (c) Facebook Inc. and Microsoft Corporation. +// Licensed under the MIT license. + +syntax = "proto2"; + +package onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. We should use version as + // xx(major) - xx(minor) - xxxx(bugfix) + // and we are starting with 0x00000001 (0.0.1), which was the + // version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x00000001; + + // IR_VERSION 0.0.2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x00000002; + + // IR VERSION 0.0.3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION = 0x00000003; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + } + + // The name field MUST be present for this version of the IR. + optional string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + optional string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + optional string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accomodate proto3 implementations. + optional AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + optional float f = 2; // float + optional int64 i = 3; // int + optional bytes s = 4; // UTF-8 string + optional TensorProto t = 5; // tensor value + optional GraphProto g = 6; // graph + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + optional string name = 1; // namespace Value + // This field MUST be present in this version of the IR. + optional TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + optional string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + optional string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + optional string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + optional string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + optional string doc_string = 6; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + optional int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + optional string domain = 4; + + // The version of the graph encoded. See Version enum below. + optional int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + optional string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + optional GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + optional string key = 1; + optional string value= 2; +}; + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + optional string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // also appears in the input list. + repeated TensorProto initializer = 5; + + // A human-readable documentation for this graph. Markdown is allowed. + optional string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // Advanced types + FLOAT16 = 10; + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + optional DataType data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + optional int64 begin = 1; + optional int64 end = 2; + } + optional Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + optional string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + optional string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + optional bytes raw_data = 9; + + // For double + // Complex64 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + optional string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// A set of pre-defined constants to be used as values for +// the standard denotation field in TensorShapeProto.Dimension +// for semantic description of the tensor dimension. +message DenotationConstProto { + // Describe a batch number dimension. + optional string DATA_BATCH = 1 [default = "DATA_BATCH"]; + // Describe a channel dimension. + optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"]; + // Describe a time dimension. + optional string DATA_TIME = 3 [default = "DATA_TIME"]; + // Describe a feature dimension. This is typically a feature + // dimension in RNN and/or spatial dimension in CNN. + optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"]; + // Describe a filter in-channel dimension. This is the dimension + // that is identical (in size) to the channel dimension of the input + // image feature maps. + optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"]; + // Describe a filter out channel dimension. This is the dimension + // that is identical (int size) to the channel dimension of the output + // image feature maps. + optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"]; + // Describe a filter spatial dimension. + optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"]; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST be present for this version of the IR. + optional TensorProto.DataType elem_type = 1; + optional TensorShapeProto shape = 2; + } + + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + } +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + optional string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + optional int64 version = 2; +}
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd index e9575af6010..cc73f2daff5 100644 --- a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd +++ b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd @@ -18,7 +18,7 @@ search test { } function mnist_softmax_onnx() { - expression: onnx("mnist_softmax") + expression: onnx_vespa("mnist_softmax") } function my_xgboost() { diff --git a/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py b/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py new file mode 100755 index 00000000000..55df3a557e9 --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py @@ -0,0 +1,12 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import onnx +from onnx import helper, TensorProto + +INPUT = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "sequence"]) +OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "sequence"]) + +nodes = [helper.make_node('Identity', ['input'], ['output'])] +graph_def = helper.make_graph( nodes, 'simple_scoring', [INPUT], [OUTPUT]) +model_def = helper.make_model(graph_def, producer_name='create_dynamic_model.py') +onnx.save(model_def, 'dynamic_model.onnx') diff --git a/config-model/src/test/integration/onnx-model/files/create_model.py b/config-model/src/test/integration/onnx-model/files/create_model.py new file mode 100755 index 00000000000..10ff92c2eda --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/create_model.py @@ -0,0 +1,37 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import onnx +from onnx import helper, TensorProto + +INPUT_1 = helper.make_tensor_value_info('first_input', TensorProto.FLOAT, [2]) +INPUT_2 = helper.make_tensor_value_info('second/input:0', TensorProto.FLOAT, [2]) +INPUT_3 = helper.make_tensor_value_info('third_input', TensorProto.FLOAT, [2]) +OUTPUT_1 = helper.make_tensor_value_info('path/to/output:0', TensorProto.FLOAT, [2]) +OUTPUT_2 = helper.make_tensor_value_info('path/to/output:1', TensorProto.FLOAT, [2]) +OUTPUT_3 = helper.make_tensor_value_info('path/to/output:2', TensorProto.FLOAT, [2]) + +nodes = [ + helper.make_node( + 'Add', + ['first_input', 'second/input:0'], + ['path/to/output:0'], + ), + helper.make_node( + 'Add', + ['third_input', 'second/input:0'], + ['path/to/output:1'] + ), + helper.make_node( + 'Add', + ['path/to/output:0', 'path/to/output:1'], + ['path/to/output:2'] + ), +] +graph_def = helper.make_graph( + nodes, + 'simple_scoring', + [INPUT_1, INPUT_2, INPUT_3], + [OUTPUT_1, OUTPUT_2, OUTPUT_3] +) +model_def = helper.make_model(graph_def, producer_name='create_model.py') +onnx.save(model_def, 'model.onnx') diff --git a/config-model/src/test/integration/onnx-model/files/create_unbound_model.py b/config-model/src/test/integration/onnx-model/files/create_unbound_model.py new file mode 100755 index 00000000000..abf733ea43f --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/create_unbound_model.py @@ -0,0 +1,12 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import onnx +from onnx import helper, TensorProto + +INPUT = helper.make_tensor_value_info('input', TensorProto.FLOAT, [-1, 2]) +OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [-1, 2]) + +nodes = [helper.make_node('Identity', ['input'], ['output'])] +graph_def = helper.make_graph( nodes, 'simple_scoring', [INPUT], [OUTPUT]) +model_def = helper.make_model(graph_def, producer_name='create_unbound_model.py') +onnx.save(model_def, 'unbound_model.onnx') diff --git a/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx b/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx new file mode 100644 index 00000000000..6bbdad2d76e --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx @@ -0,0 +1,13 @@ +create_dynamic_model.py:x + +inputoutput"Identitysimple_scoringZ$ +input + +batch + +sequenceb% +output + +batch + +sequenceB
\ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/files/model.onnx b/config-model/src/test/integration/onnx-model/files/model.onnx new file mode 100644 index 00000000000..f3898205c6a --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/model.onnx @@ -0,0 +1,34 @@ +create_model.py:í +4 +first_input +second/input:0path/to/output:0"Add +4 +third_input +second/input:0path/to/output:1"Add +; +path/to/output:0 +path/to/output:1path/to/output:2"Addsimple_scoringZ +first_input + + +Z +second/input:0 + + +Z +third_input + + +b +path/to/output:0 + + +b +path/to/output:1 + + +b +path/to/output:2 + + +B
\ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/files/summary_model.onnx b/config-model/src/test/integration/onnx-model/files/summary_model.onnx new file mode 100644 index 00000000000..f3898205c6a --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/summary_model.onnx @@ -0,0 +1,34 @@ +create_model.py:í +4 +first_input +second/input:0path/to/output:0"Add +4 +third_input +second/input:0path/to/output:1"Add +; +path/to/output:0 +path/to/output:1path/to/output:2"Addsimple_scoringZ +first_input + + +Z +second/input:0 + + +Z +third_input + + +b +path/to/output:0 + + +b +path/to/output:1 + + +b +path/to/output:2 + + +B
\ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/files/unbound_model.onnx b/config-model/src/test/integration/onnx-model/files/unbound_model.onnx new file mode 100644 index 00000000000..155b3125256 --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/unbound_model.onnx @@ -0,0 +1,11 @@ +create_unbound_model.py:p + +inputoutput"Identitysimple_scoringZ +input + +ÿÿÿÿÿÿÿÿÿ +b! +output + +ÿÿÿÿÿÿÿÿÿ +B
\ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd b/config-model/src/test/integration/onnx-model/schemas/test.sd index 0f0fa694e6f..a87222e77ee 100644 --- a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd +++ b/config-model/src/test/integration/onnx-model/schemas/test.sd @@ -14,7 +14,7 @@ search test { } onnx-model my_model { - file: files/ranking_model.onnx + file: files/model.onnx input first_input: attribute(document_field) input "second/input:0": constant(my_constant) input "third_input": my_function @@ -22,19 +22,31 @@ search test { } onnx-model another_model { - file: files/ranking_model.onnx + file: files/model.onnx input first_input: attribute(document_field) input "second/input:0": constant(my_constant) input "third_input": another_function output "path/to/output:2": out } + onnx-model dynamic_model { + file: files/dynamic_model.onnx + input input: my_function + output output: my_output + } + + onnx-model unbound_model { + file: files/unbound_model.onnx + input input: my_function + output output: my_output + } + rank-profile test_model_config { function my_function() { expression: tensor(d0[2])(1) } first-phase { - expression: onnxModel(my_model).out + expression: onnxModel(my_model).out{d0:1} } } @@ -49,7 +61,7 @@ search test { expression: my_function() } first-phase { - expression: onnxModel("files/ranking_model.onnx", "path/to/output:1") + expression: onnxModel("files/model.onnx", "path/to/output:1"){d0:1} } } @@ -62,9 +74,39 @@ search test { } summary-features { onnxModel(another_model).out - onnxModel("files/ranking_model.onnx", "path/to/output:2") + onnxModel("files/summary_model.onnx", "path/to/output:2") } + } + rank-profile test_dynamic_model { + function my_function() { + expression: tensor(d0[1],d1[2])(d1) + } + first-phase { + expression: onnxModel(dynamic_model){d0:0,d1:1} + } } + rank-profile test_dynamic_model_2 { + function my_function_2() { + expression: tensor(d0[1],d1[3])(d1) + } + function my_function() { + expression: my_function_2() + } + first-phase { + expression: onnxModel(dynamic_model){d0:0,d1:2} + } + } + + rank-profile test_unbound_model { + function my_function() { + expression: tensor(d0[1],d1[2])(d1) + } + first-phase { + expression: onnxModel(unbound_model){d0:0,d1:1} + } + } + + } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java index d9b0c70dfdd..f8a379b4027 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java @@ -1,67 +1,126 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.search.DocumentDatabase; import com.yahoo.vespa.model.search.IndexedSearchCluster; -import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg; +import org.junit.After; import org.junit.Test; import static org.junit.Assert.assertEquals; public class RankingExpressionWithOnnxModelTestCase { + private final Path applicationDir = Path.fromString("src/test/integration/onnx-model/"); + + @After + public void removeGeneratedModelFiles() { + IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + } + @Test - public void testOnnxModelFeature() { - VespaModel model = new VespaModelCreatorWithFilePkg("src/test/integration/onnx-model").create(); - DocumentDatabase db = ((IndexedSearchCluster)model.getSearchClusters().get(0)).getDocumentDbs().get(0); - assertTransformedFeature(db); - assertGeneratedConfig(db); + public void testOnnxModelFeature() throws Exception { + VespaModel model = loadModel(applicationDir); + assertTransformedFeature(model); + assertGeneratedConfig(model); + + Path storedApplicationDir = applicationDir.append("copy"); + try { + storedApplicationDir.toFile().mkdirs(); + IOUtils.copy(applicationDir.append("services.xml").toString(), storedApplicationDir.append("services.xml").toString()); + IOUtils.copyDirectory(applicationDir.append("schemas").toFile(), storedApplicationDir.append("schemas").toFile()); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + + VespaModel storedModel = loadModel(storedApplicationDir); + assertTransformedFeature(storedModel); + assertGeneratedConfig(storedModel); + } + finally { + IOUtils.recursiveDeleteDir(storedApplicationDir.toFile()); + } } - private void assertGeneratedConfig(DocumentDatabase db) { + private VespaModel loadModel(Path path) throws Exception { + FilesApplicationPackage applicationPackage = FilesApplicationPackage.fromFile(path.toFile()); + DeployState state = new DeployState.Builder().applicationPackage(applicationPackage).build(); + return new VespaModel(state); + } + + private void assertGeneratedConfig(VespaModel model) { + DocumentDatabase db = ((IndexedSearchCluster)model.getSearchClusters().get(0)).getDocumentDbs().get(0); OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder(); ((OnnxModelsConfig.Producer) db).getConfig(builder); OnnxModelsConfig config = new OnnxModelsConfig(builder); - assertEquals(3, config.model().size()); + assertEquals(6, config.model().size()); - assertEquals("my_model", config.model(1).name()); + assertEquals("my_model", config.model(0).name()); + assertEquals(3, config.model(0).input().size()); + assertEquals("second/input:0", config.model(0).input(0).name()); + assertEquals("constant(my_constant)", config.model(0).input(0).source()); + assertEquals("first_input", config.model(0).input(1).name()); + assertEquals("attribute(document_field)", config.model(0).input(1).source()); + assertEquals("third_input", config.model(0).input(2).name()); + assertEquals("rankingExpression(my_function)", config.model(0).input(2).source()); + assertEquals(3, config.model(0).output().size()); + assertEquals("path/to/output:0", config.model(0).output(0).name()); + assertEquals("out", config.model(0).output(0).as()); + assertEquals("path/to/output:1", config.model(0).output(1).name()); + assertEquals("path_to_output_1", config.model(0).output(1).as()); + assertEquals("path/to/output:2", config.model(0).output(2).name()); + assertEquals("path_to_output_2", config.model(0).output(2).as()); + + assertEquals("files_model_onnx", config.model(1).name()); assertEquals(3, config.model(1).input().size()); - assertEquals("first_input", config.model(1).input(0).name()); - assertEquals("attribute(document_field)", config.model(1).input(0).source()); - assertEquals("second/input:0", config.model(1).input(1).name()); - assertEquals("constant(my_constant)", config.model(1).input(1).source()); - assertEquals("third_input", config.model(1).input(2).name()); - assertEquals("rankingExpression(my_function)", config.model(1).input(2).source()); - assertEquals(1, config.model(1).output().size()); + assertEquals(3, config.model(1).output().size()); assertEquals("path/to/output:0", config.model(1).output(0).name()); - assertEquals("out", config.model(1).output(0).as()); - - assertEquals("files_ranking_model_onnx", config.model(0).name()); - assertEquals(0, config.model(0).input().size()); - assertEquals(2, config.model(0).output().size()); - assertEquals("path/to/output:1", config.model(0).output(0).name()); - assertEquals("path_to_output_1", config.model(0).output(0).as()); - assertEquals("path/to/output:2", config.model(0).output(1).name()); - assertEquals("path_to_output_2", config.model(0).output(1).as()); + assertEquals("path_to_output_0", config.model(1).output(0).as()); + assertEquals("path/to/output:1", config.model(1).output(1).name()); + assertEquals("path_to_output_1", config.model(1).output(1).as()); + assertEquals("path/to/output:2", config.model(1).output(2).name()); + assertEquals("path_to_output_2", config.model(1).output(2).as()); + assertEquals("files_model_onnx", config.model(1).name()); assertEquals("another_model", config.model(2).name()); assertEquals("third_input", config.model(2).input(2).name()); assertEquals("rankingExpression(another_function)", config.model(2).input(2).source()); + + assertEquals("files_summary_model_onnx", config.model(3).name()); + assertEquals(3, config.model(3).input().size()); + assertEquals(3, config.model(3).output().size()); + + assertEquals("dynamic_model", config.model(5).name()); + assertEquals(1, config.model(5).input().size()); + assertEquals(1, config.model(5).output().size()); + assertEquals("rankingExpression(my_function)", config.model(5).input(0).source()); + + assertEquals("unbound_model", config.model(4).name()); + assertEquals(1, config.model(4).input().size()); + assertEquals(1, config.model(4).output().size()); + assertEquals("rankingExpression(my_function)", config.model(4).input(0).source()); + } - private void assertTransformedFeature(DocumentDatabase db) { + private void assertTransformedFeature(VespaModel model) { + DocumentDatabase db = ((IndexedSearchCluster)model.getSearchClusters().get(0)).getDocumentDbs().get(0); RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder(); ((RankProfilesConfig.Producer) db).getConfig(builder); RankProfilesConfig config = new RankProfilesConfig(builder); - assertEquals(5, config.rankprofile().size()); + assertEquals(8, config.rankprofile().size()); assertEquals("test_model_config", config.rankprofile(2).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name()); assertEquals("vespa.rank.firstphase", config.rankprofile(2).fef().property(2).name()); - assertEquals("onnxModel(my_model).out", config.rankprofile(2).fef().property(2).value()); + assertEquals("rankingExpression(firstphase)", config.rankprofile(2).fef().property(2).value()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(2).fef().property(3).name()); + assertEquals("onnxModel(my_model).out{d0:1}", config.rankprofile(2).fef().property(3).value()); assertEquals("test_generated_model_config", config.rankprofile(3).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(3).fef().property(0).name()); @@ -69,16 +128,34 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("rankingExpression(second_input).rankingScript", config.rankprofile(3).fef().property(4).name()); assertEquals("rankingExpression(third_input).rankingScript", config.rankprofile(3).fef().property(6).name()); assertEquals("vespa.rank.firstphase", config.rankprofile(3).fef().property(8).name()); - assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_1", config.rankprofile(3).fef().property(8).value()); + assertEquals("rankingExpression(firstphase)", config.rankprofile(3).fef().property(8).value()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(3).fef().property(9).name()); + assertEquals("onnxModel(files_model_onnx).path_to_output_1{d0:1}", config.rankprofile(3).fef().property(9).value()); assertEquals("test_summary_features", config.rankprofile(4).name()); assertEquals("rankingExpression(another_function).rankingScript", config.rankprofile(4).fef().property(0).name()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(4).fef().property(3).name()); assertEquals("1", config.rankprofile(4).fef().property(3).value()); assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(4).name()); - assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(4).value()); + assertEquals("onnxModel(files_summary_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(4).value()); assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(5).name()); - assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(5).value()); + assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(5).value()); + + assertEquals("test_dynamic_model", config.rankprofile(5).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(5).fef().property(0).name()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(5).fef().property(3).name()); + assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:1}", config.rankprofile(5).fef().property(3).value()); + + assertEquals("test_dynamic_model_2", config.rankprofile(6).name()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(6).fef().property(5).name()); + assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value()); + + assertEquals("test_unbound_model", config.rankprofile(7).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(0).name()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(7).fef().property(3).name()); + assertEquals("onnxModel(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(7).fef().property(3).value()); + + } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 6bf69907609..40bf970a313 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -70,7 +70,7 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", - "onnx('mnist_softmax.onnx')", + "onnx_vespa('mnist_softmax.onnx')", "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -87,7 +87,7 @@ public class RankingExpressionWithOnnxTestCase { queryProfile, queryProfileType); RankProfileSearchFixture search = fixtureWith("query(mytensor)", - "onnx('mnist_softmax.onnx')", + "onnx_vespa('mnist_softmax.onnx')", null, null, "Placeholder", @@ -99,7 +99,7 @@ public class RankingExpressionWithOnnxTestCase { public void testOnnxReferenceWithDocumentFeature() { StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("attribute(mytensor)", - "onnx('mnist_softmax.onnx')", + "onnx_vespa('mnist_softmax.onnx')", null, "field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }", "Placeholder", @@ -117,7 +117,7 @@ public class RankingExpressionWithOnnxTestCase { "</query-profile-type>"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, queryProfileType); RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)", - "onnx('mnist_softmax.onnx')", + "onnx_vespa('mnist_softmax.onnx')", "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }", "field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }", "Placeholder", @@ -129,21 +129,21 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testNestedOnnxReference() { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "5 + sum(onnx('mnist_softmax.onnx'))"); + "5 + sum(onnx_vespa('mnist_softmax.onnx'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); } @Test public void testOnnxReferenceWithSpecifiedOutput() { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 'layer_add')"); + "onnx_vespa('mnist_softmax.onnx', 'layer_add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test public void testOnnxReferenceWithSpecifiedOutputAndSignature() { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 'default.layer_add')"); + "onnx_vespa('mnist_softmax.onnx', 'default.layer_add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @@ -155,7 +155,7 @@ public class RankingExpressionWithOnnxTestCase { new QueryProfileRegistry(), " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: onnx('mnist_softmax.onnx')" + + " expression: onnx_vespa('mnist_softmax.onnx')" + " }\n" + " }"); search.compileRankProfile("my_profile", applicationDir.append("models")); @@ -164,7 +164,7 @@ public class RankingExpressionWithOnnxTestCase { } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx('mnist_softmax.onnx'): " + + "onnx_vespa('mnist_softmax.onnx'): " + "Model refers input 'Placeholder' of type tensor<float>(d0[1],d1[784]) but this function is " + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); @@ -175,13 +175,13 @@ public class RankingExpressionWithOnnxTestCase { public void testOnnxReferenceWithWrongFunctionType() { try { RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d5[10])(0.0)", - "onnx('mnist_softmax.onnx')"); + "onnx_vespa('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx('mnist_softmax.onnx'): " + + "onnx_vespa('mnist_softmax.onnx'): " + "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[1],d1[784]), " + "but this function returns tensor(d0[1],d5[10])", Exceptions.toMessageString(expected)); @@ -192,13 +192,13 @@ public class RankingExpressionWithOnnxTestCase { public void testOnnxReferenceSpecifyingNonExistingOutput() { try { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 'y')"); + "onnx_vespa('mnist_softmax.onnx', 'y')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx('mnist_softmax.onnx','y'): " + + "onnx_vespa('mnist_softmax.onnx','y'): " + "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add", Exceptions.toMessageString(expected)); } @@ -207,7 +207,7 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testImportingFromStoredExpressions() throws IOException { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "onnx('mnist_softmax.onnx')"); + "onnx_vespa('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); // At this point the expression is stored - copy application to another location which do not have a models dir @@ -218,7 +218,7 @@ public class RankingExpressionWithOnnxTestCase { storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); RankProfileSearchFixture searchFromStored = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", - "onnx('mnist_softmax.onnx')", + "onnx_vespa('mnist_softmax.onnx')", null, null, "Placeholder", @@ -243,7 +243,7 @@ public class RankingExpressionWithOnnxTestCase { " expression: tensor<float>(d1[10],d2[784])(0.0)\n" + " }\n" + " first-phase {\n" + - " expression: onnx('mnist_softmax.onnx')" + + " expression: onnx_vespa('mnist_softmax.onnx')" + " }\n" + " }" + " rank-profile my_profile_child inherits my_profile {\n" + @@ -288,7 +288,7 @@ public class RankingExpressionWithOnnxTestCase { " expression: tensor<float>(d0[3])(0.0)\n" + " }\n" + " first-phase {\n" + - " expression: onnx('" + name + ".onnx')" + + " expression: onnx_vespa('" + name + ".onnx')" + " }\n" + " }"; final String functionName = "imported_ml_function_" + name + "_exp_output"; @@ -310,7 +310,7 @@ public class RankingExpressionWithOnnxTestCase { " expression: tensor<float>(d0[3])(0.0)\n" + " }\n" + " first-phase {\n" + - " expression: onnx('" + name + ".onnx')" + + " expression: onnx_vespa('" + name + ".onnx')" + " }\n" + " }" + " rank-profile my_profile_child inherits my_profile {\n" + diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchive.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchive.java index a00992da815..8f91a8127bd 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchive.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchive.java @@ -221,6 +221,9 @@ public class SystemFlagsDataArchive { } else if (dimension.isEqualTo(DimensionHelper.toWire(FetchVector.Dimension.CONSOLE_USER_EMAIL))) { condition.get("values").forEachArrayElement(conditionValue -> conditionValue.asString() .orElseThrow(() -> new IllegalArgumentException("Non-string email address: " + conditionValue))); + } else if (dimension.isEqualTo(DimensionHelper.toWire(FetchVector.Dimension.TENANT_ID))) { + condition.get("values").forEachArrayElement(conditionValue -> conditionValue.asString() + .orElseThrow(() -> new IllegalArgumentException("Non-string tenant ID: " + conditionValue))); } })); } diff --git a/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchiveTest.java b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchiveTest.java index aca991ec637..771e42e85f9 100644 --- a/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchiveTest.java +++ b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchiveTest.java @@ -236,6 +236,30 @@ public class SystemFlagsDataArchiveTest { } } + @Test + public void normalize_json_fail_on_invalid_tenant_id() { + try { + SystemFlagsDataArchive.normalizeJson("{\n" + + " \"id\": \"foo\",\n" + + " \"rules\": [\n" + + " {\n" + + " \"conditions\": [\n" + + " {\n" + + " \"type\": \"whitelist\",\n" + + " \"dimension\": \"tenant\",\n" + + " \"values\": [ 123 ]\n" + + " }\n" + + " ],\n" + + " \"value\": true\n" + + " }\n" + + " ]\n" + + "}\n"); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("Non-string tenant ID: 123", e.getMessage()); + } + } + private static void assertArchiveReturnsCorrectTestFlagDataForTarget(SystemFlagsDataArchive archive) { assertFlagDataHasValue(archive, MY_TEST_FLAG, mainControllerTarget, "main.controller"); assertFlagDataHasValue(archive, MY_TEST_FLAG, prodUsWestCfgTarget, "main.prod.us-west-1"); diff --git a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerEngine.java b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerEngine.java index bbd622a0d2a..81074c5ea37 100644 --- a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerEngine.java +++ b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerEngine.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.hosted.dockerapi; import com.github.dockerjava.api.DockerClient; +import com.github.dockerjava.api.command.DockerCmdExecFactory; import com.github.dockerjava.api.command.ExecCreateCmdResponse; import com.github.dockerjava.api.command.InspectContainerResponse; import com.github.dockerjava.api.command.InspectExecResponse; @@ -18,6 +19,7 @@ import com.github.dockerjava.core.DefaultDockerClientConfig; import com.github.dockerjava.core.DockerClientConfig; import com.github.dockerjava.core.DockerClientImpl; import com.github.dockerjava.core.async.ResultCallbackTemplate; +import com.github.dockerjava.core.command.AuthCmdImpl; import com.github.dockerjava.core.command.ExecStartResultCallback; import com.github.dockerjava.core.command.PullImageResultCallback; import com.github.dockerjava.jaxrs.JerseyDockerCmdExecFactory; @@ -61,6 +63,7 @@ public class DockerEngine implements ContainerEngine { private final Set<DockerImage> scheduledPulls = new HashSet<>(); private final DockerClient dockerClient; + private final DockerCmdExecFactory dockerFactory; private final DockerImageGarbageCollector dockerImageGC; private final Metrics metrics; private final Counter numberOfDockerApiFails; @@ -71,8 +74,9 @@ public class DockerEngine implements ContainerEngine { this(createDockerClient(), metrics, Clock.systemUTC()); } - DockerEngine(DockerClient dockerClient, Metrics metrics, Clock clock) { - this.dockerClient = dockerClient; + DockerEngine(DockerClientWithExecFactory clientWithExecFactory, Metrics metrics, Clock clock) { + this.dockerClient = clientWithExecFactory.dockerClient; + this.dockerFactory = clientWithExecFactory.dockerCmdExecFactory; this.dockerImageGC = new DockerImageGarbageCollector(this); this.metrics = metrics; this.clock = clock; @@ -92,11 +96,12 @@ public class DockerEngine implements ContainerEngine { logger.log(Level.INFO, "Starting download of " + image.asString()); if (!registryCredentials.equals(RegistryCredentials.none)) { AuthConfig authConfig = new AuthConfig().withUsername(registryCredentials.username()) - .withPassword(registryCredentials.password()) - .withRegistryAddress(registryCredentials.registryAddress()); - dockerClient.authCmd() - .withAuthConfig(authConfig) - .exec(); + .withPassword(registryCredentials.password()) + .withRegistryAddress(registryCredentials.registryAddress()); + + // Need to create AuthCmdImpl directly since DockerClient.authCmd() will throw + // exception when username/registry url is not set + new AuthCmdImpl(this.dockerFactory.createAuthCmdExec(), authConfig).exec(); } dockerClient.pullImageCmd(image.asString()).exec(new ImagePullCallback(image)); return true; @@ -414,7 +419,7 @@ public class DockerEngine implements ContainerEngine { } } - private static DockerClient createDockerClient() { + private static DockerClientWithExecFactory createDockerClient() { JerseyDockerCmdExecFactory dockerFactory = new JerseyDockerCmdExecFactory() .withMaxPerRouteConnections(10) .withMaxTotalConnections(100) @@ -425,7 +430,18 @@ public class DockerEngine implements ContainerEngine { .withDockerHost("unix:///var/run/docker.sock") .build(); - return DockerClientImpl.getInstance(dockerClientConfig) - .withDockerCmdExecFactory(dockerFactory); + return new DockerClientWithExecFactory( + DockerClientImpl.getInstance(dockerClientConfig).withDockerCmdExecFactory(dockerFactory), + dockerFactory); + } + + static class DockerClientWithExecFactory { + private final DockerClient dockerClient; + private final DockerCmdExecFactory dockerCmdExecFactory; + + public DockerClientWithExecFactory(DockerClient dockerClient, DockerCmdExecFactory dockerCmdExecFactory) { + this.dockerClient = dockerClient; + this.dockerCmdExecFactory = dockerCmdExecFactory; + } } } diff --git a/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerEngineTest.java b/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerEngineTest.java index 71bdb321305..66bcf89090b 100644 --- a/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerEngineTest.java +++ b/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerEngineTest.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.hosted.dockerapi; import com.github.dockerjava.api.DockerClient; import com.github.dockerjava.api.async.ResultCallback; +import com.github.dockerjava.api.command.DockerCmdExecFactory; import com.github.dockerjava.api.command.ExecCreateCmd; import com.github.dockerjava.api.command.ExecCreateCmdResponse; import com.github.dockerjava.api.command.ExecStartCmd; @@ -42,7 +43,8 @@ public class DockerEngineTest { private final DockerClient dockerClient = mock(DockerClient.class); private final Metrics metrics = new Metrics(); private final ManualClock clock = new ManualClock(); - private final DockerEngine docker = new DockerEngine(dockerClient, metrics, clock); + private final DockerEngine docker = new DockerEngine( + new DockerEngine.DockerClientWithExecFactory(dockerClient, mock(DockerCmdExecFactory.class)), metrics, clock); @Test public void testExecuteCompletes() { diff --git a/eval/src/vespa/eval/eval/value.h b/eval/src/vespa/eval/eval/value.h index 0902a7c1752..e876ba7b472 100644 --- a/eval/src/vespa/eval/eval/value.h +++ b/eval/src/vespa/eval/eval/value.h @@ -103,6 +103,9 @@ public: static const ValueType &shared_type() { return _type; } }; +extern template class ScalarValue<double>; +extern template class ScalarValue<float>; + using DoubleValue = ScalarValue<double>; /** diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml index 4beaf6086a6..181ef6dffbd 100644 --- a/fat-model-dependencies/pom.xml +++ b/fat-model-dependencies/pom.xml @@ -221,5 +221,10 @@ <artifactId>jdisc_http_service</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + <version>${protobuf.version}</version> + </dependency> </dependencies> </project> diff --git a/flags/src/main/java/com/yahoo/vespa/flags/FetchVector.java b/flags/src/main/java/com/yahoo/vespa/flags/FetchVector.java index 89c4f16e27b..37849b65adf 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/FetchVector.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/FetchVector.java @@ -24,6 +24,9 @@ public class FetchVector { * Note: If this enum is changed, you must also change {@link DimensionHelper}. */ public enum Dimension { + /** A legal value for TenantName, e.g. vespa-team */ + TENANT_ID, + /** Value from ApplicationId::serializedForm of the form tenant:applicationName:instance. */ APPLICATION_ID, diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index e622e1dd419..bf3b497bc3f 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -15,6 +15,7 @@ import static com.yahoo.vespa.flags.FetchVector.Dimension.APPLICATION_ID; import static com.yahoo.vespa.flags.FetchVector.Dimension.CONSOLE_USER_EMAIL; import static com.yahoo.vespa.flags.FetchVector.Dimension.HOSTNAME; import static com.yahoo.vespa.flags.FetchVector.Dimension.NODE_TYPE; +import static com.yahoo.vespa.flags.FetchVector.Dimension.TENANT_ID; import static com.yahoo.vespa.flags.FetchVector.Dimension.VESPA_VERSION; import static com.yahoo.vespa.flags.FetchVector.Dimension.ZONE_ID; @@ -266,7 +267,7 @@ public class Flags { "tenant-budget-quota", -1, "The budget in cents/hr a tenant is allowed spend per instance, as calculated by NodeResources", "Only takes effect on next deployment, if set to a value other than the default for flag!", - APPLICATION_ID + TENANT_ID ); public static final UnboundBooleanFlag ONLY_PUBLIC_ACCESS = defineFeatureFlag( diff --git a/flags/src/main/java/com/yahoo/vespa/flags/json/DimensionHelper.java b/flags/src/main/java/com/yahoo/vespa/flags/json/DimensionHelper.java index 4b989b8f819..f109d3c950b 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/json/DimensionHelper.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/json/DimensionHelper.java @@ -20,6 +20,7 @@ public class DimensionHelper { serializedDimensions.put(FetchVector.Dimension.CLUSTER_TYPE, "cluster-type"); serializedDimensions.put(FetchVector.Dimension.VESPA_VERSION, "vespa-version"); serializedDimensions.put(FetchVector.Dimension.CONSOLE_USER_EMAIL, "console-user-email"); + serializedDimensions.put(FetchVector.Dimension.TENANT_ID, "tenant"); if (serializedDimensions.size() != FetchVector.Dimension.values().length) { throw new IllegalStateException(FetchVectorHelper.class.getName() + " is not in sync with " + diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDb.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDb.java index b1585922f38..070bf98bf87 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDb.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDb.java @@ -49,6 +49,8 @@ public class QuestMetricsDb implements MetricsDb { private final String dataDir; private final CairoEngine engine; + private long highestTimestampAdded = 0; + @Inject public QuestMetricsDb() { this(Defaults.getDefaults().underVespaHome("var/db/vespa/autoscaling"), Clock.systemUTC()); @@ -67,6 +69,7 @@ public class QuestMetricsDb implements MetricsDb { // silence Questdb's custom logging system IOUtils.writeFile(new File(dataDir, "quest-log.conf"), new byte[0]); System.setProperty("questdbLog", dataDir + "/quest-log.conf"); + System.setProperty("org.jooq.no-logo", "true"); CairoConfiguration configuration = new DefaultCairoConfiguration(dataDir); engine = new CairoEngine(configuration); @@ -77,7 +80,9 @@ public class QuestMetricsDb implements MetricsDb { public void add(Collection<Pair<String, MetricSnapshot>> snapshots) { try (TableWriter writer = engine.getWriter(newContext().getCairoSecurityContext(), tableName)) { for (var snapshot : snapshots) { - long atMillis = snapshot.getSecond().at().toEpochMilli(); + long atMillis = adjustIfRecent(snapshot.getSecond().at().toEpochMilli(), highestTimestampAdded); + if (atMillis < highestTimestampAdded) continue; // Ignore old data + highestTimestampAdded = atMillis; TableWriter.Row row = writer.newRow(atMillis * 1000); // in microseconds row.putStr(0, snapshot.getFirst()); row.putFloat(2, (float)snapshot.getSecond().cpu()); @@ -154,6 +159,17 @@ public class QuestMetricsDb implements MetricsDb { } } + private long adjustIfRecent(long timestamp, long highestTimestampAdded) { + if (timestamp >= highestTimestampAdded) return timestamp; + + // We cannot add old data to QuestDb, but we want to use all recent information + long oneMinute = 60 * 1000; + if (timestamp >= highestTimestampAdded - oneMinute) return highestTimestampAdded; + + // Too old; discard + return timestamp; + } + private ListMap<String, MetricSnapshot> getSnapshots(Instant startTime, Set<String> hostnames, SqlCompiler compiler, diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DynamicProvisioningMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DynamicProvisioningMaintainer.java index e5f23a30968..a43b2655ea3 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DynamicProvisioningMaintainer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DynamicProvisioningMaintainer.java @@ -24,6 +24,7 @@ import com.yahoo.vespa.hosted.provision.provisioning.NodeResourceComparator; import com.yahoo.vespa.hosted.provision.provisioning.ProvisionedHost; import com.yahoo.yolean.Exceptions; +import javax.naming.NameNotFoundException; import java.time.Duration; import java.util.ArrayList; import java.util.Comparator; @@ -93,7 +94,10 @@ public class DynamicProvisioningMaintainer extends NodeRepositoryMaintainer { nodeRepository().failRecursively( host.hostname(), Agent.operator, "Failed by HostProvisioner due to provisioning failure"); } catch (RuntimeException e) { - log.log(Level.WARNING, "Failed to provision " + host.hostname() + ", will retry in " + interval(), e); + if (e.getCause() instanceof NameNotFoundException) + log.log(Level.INFO, "Failed to provision " + host.hostname() + ", will retry in " + interval() + ": " + e.getMessage()); + else + log.log(Level.WARNING, "Failed to provision " + host.hostname() + ", will retry in " + interval(), e); } }); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/InfraDeployerImpl.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/InfraDeployerImpl.java index 91c83e1c608..d81d16fe62e 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/InfraDeployerImpl.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/InfraDeployerImpl.java @@ -20,6 +20,7 @@ import com.yahoo.vespa.hosted.provision.maintenance.InfrastructureVersions; import com.yahoo.vespa.service.monitor.DuperModelInfraApi; import com.yahoo.vespa.service.monitor.InfraApplicationApi; +import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.logging.Level; @@ -53,7 +54,10 @@ public class InfraDeployerImpl implements InfraDeployer { @Override public void activateAllSupportedInfraApplications(boolean propagateException) { - duperModel.getSupportedInfraApplications().forEach(api -> { + duperModel.getSupportedInfraApplications().stream() + // nodes cannot be activated before their host, so try to activate the host first + .sorted(Comparator.comparing(n -> !n.getCapacity().type().isHost())) + .forEach(api -> { var application = api.getApplicationId(); var deployment = new InfraDeployment(api); try { diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java index 1e98160955c..68e11c4c995 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java @@ -127,11 +127,7 @@ class NodeAllocation { ++rejectedDueToClashingParentHost; continue; } - if ( ! exclusiveTo(application, candidate.parentHostname())) { - ++rejectedDueToExclusivity; - continue; - } - if ( requestedNodes.isExclusive() && ! hostsOnly(application, candidate.parentHostname())) { + if ( violatesExclusivity(candidate)) { ++rejectedDueToExclusivity; continue; } @@ -158,7 +154,7 @@ class NodeAllocation { if (violatesParentHostPolicy(candidate)) return true; if ( ! hasCompatibleFlavor(candidate)) return true; if (candidate.wantToRetire()) return true; - if (requestedNodes.isExclusive() && ! hostsOnly(application, candidate.parentHostname())) return true; + if (violatesExclusivity(candidate)) return true; return false; } @@ -182,35 +178,23 @@ class NodeAllocation { return false; } - /** - * If a parent host is given, and it hosts another application which requires exclusive access - * to the physical host, then we cannot host this application on it. - */ - private boolean exclusiveTo(ApplicationId applicationId, Optional<String> parentHostname) { - if (parentHostname.isEmpty()) return true; - for (Node nodeOnHost : allNodes.childrenOf(parentHostname.get())) { - if (nodeOnHost.allocation().isEmpty()) continue; - if ( nodeOnHost.allocation().get().membership().cluster().isExclusive() && - ! allocatedTo(applicationId, nodeOnHost)) - return false; - } - return true; - } + private boolean violatesExclusivity(NodeCandidate candidate) { + if (candidate.parentHostname().isEmpty()) return false; - /** Returns true if this host only hosts the given application (in any instance) */ - private boolean hostsOnly(ApplicationId application, Optional<String> parentHostname) { - if (parentHostname.isEmpty()) return true; // yes, as host is exclusive + // In dynamic provisioned zones a node requiring exclusivity must be on a host that has exclusiveTo equal to its owner + if (nodeRepository.zone().getCloud().dynamicProvisioning()) + return requestedNodes.isExclusive() && + ! candidate.parent.flatMap(Node::exclusiveTo).map(application::equals).orElse(false); - for (Node nodeOnHost : allNodes.childrenOf(parentHostname.get())) { + // In non-dynamic provisioned zones we require that if either of the nodes on the host requires exclusivity, + // then all the nodes on the host must have the same owner + for (Node nodeOnHost : allNodes.childrenOf(candidate.parentHostname().get())) { if (nodeOnHost.allocation().isEmpty()) continue; - if ( ! allocatedTo(application, nodeOnHost)) return false; + if (requestedNodes.isExclusive() || nodeOnHost.allocation().get().membership().cluster().isExclusive()) { + if ( ! nodeOnHost.allocation().get().owner().equals(application)) return true; + } } - return true; - } - - private boolean allocatedTo(ApplicationId applicationId, Node node) { - if (node.allocation().isEmpty()) return false; - return node.allocation().get().owner().equals(applicationId); + return false; } /** @@ -390,7 +374,7 @@ class NodeAllocation { /** Prefer to unretire nodes we don't want to retire, and otherwise those with lower index */ private List<NodeCandidate> byUnretiringPriority(Collection<NodeCandidate> candidates) { return candidates.stream() - .sorted(Comparator.comparing((NodeCandidate n) -> n.wantToRetire()) + .sorted(Comparator.comparing(NodeCandidate::wantToRetire) .thenComparing(n -> n.allocation().get().membership().index())) .collect(Collectors.toList()); } @@ -407,7 +391,7 @@ class NodeAllocation { reasons.add("insufficient real resources on hosts"); if (reasons.isEmpty()) return ""; - return ": Not enough nodes available due to " + reasons.stream().collect(Collectors.joining(", ")); + return ": Not enough nodes available due to " + String.join(", ", reasons); } static class FlavorCount { diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodePrioritizer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodePrioritizer.java index 7ddb5fec3ed..4f5b2fd086e 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodePrioritizer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodePrioritizer.java @@ -132,7 +132,7 @@ public class NodePrioritizer { if ( !canAllocateNew) return; for (Node host : allNodes) { - if (host.type() == NodeType.host && !nodeRepository.canAllocateTenantNodeTo(host)) continue; + if ( ! nodeRepository.canAllocateTenantNodeTo(host)) continue; if (host.reservedTo().isPresent() && !host.reservedTo().get().equals(application.tenant())) continue; if (host.exclusiveTo().isPresent()) continue; // Never allocate new nodes to exclusive hosts if ( spareHosts.contains(host) && !canAllocateToSpareHosts) continue; @@ -156,7 +156,7 @@ public class NodePrioritizer { .filter(node -> node.allocation().isPresent()) .filter(node -> node.allocation().get().owner().equals(application)) .filter(node -> node.allocation().get().membership().cluster().id().equals(clusterSpec.id())) - .filter(node -> node.state() == Node.State.active || canStillAllocateToParentOf(node)) + .filter(node -> node.state() == Node.State.active || canStillAllocate(node)) .map(node -> candidateFrom(node, false)) .forEach(nodes::add); } @@ -204,8 +204,8 @@ public class NodePrioritizer { * * @return true if we still want to allocate the given node to its parent */ - private boolean canStillAllocateToParentOf(Node node) { - if (node.parentHostname().isEmpty()) return true; + private boolean canStillAllocate(Node node) { + if (node.type() != NodeType.tenant || node.parentHostname().isEmpty()) return true; Optional<Node> parent = allNodes.parentOf(node); if (parent.isEmpty()) return false; return nodeRepository.canAllocateTenantNodeTo(parent.get()); diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDbTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDbTest.java index a1cc66ffa28..6d52fb29160 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDbTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDbTest.java @@ -69,6 +69,30 @@ public class QuestMetricsDbTest { } @Test + public void testWriteOldData() { + String dataDir = "data/QuestMetricsDbWriteOldData"; + IOUtils.recursiveDeleteDir(new File(dataDir)); + IOUtils.createDirectory(dataDir + "/metrics"); + ManualClock clock = new ManualClock("2020-10-01T00:00:00"); + QuestMetricsDb db = new QuestMetricsDb(dataDir, clock); + Instant startTime = clock.instant(); + clock.advance(Duration.ofSeconds(300)); + db.add(timeseriesAt(10, clock.instant(), "host1", "host2", "host3")); + clock.advance(Duration.ofSeconds(1)); + + List<NodeTimeseries> nodeTimeSeries1 = db.getNodeTimeseries(startTime, Set.of("host1")); + assertEquals(10, nodeTimeSeries1.get(0).size()); + + db.add(timeseriesAt(10, clock.instant().minus(Duration.ofSeconds(20)), "host1", "host2", "host3")); + List<NodeTimeseries> nodeTimeSeries2 = db.getNodeTimeseries(startTime, Set.of("host1")); + assertEquals("Recent data is accepted", 20, nodeTimeSeries2.get(0).size()); + + db.add(timeseriesAt(10, clock.instant().minus(Duration.ofSeconds(200)), "host1", "host2", "host3")); + List<NodeTimeseries> nodeTimeSeries3 = db.getNodeTimeseries(startTime, Set.of("host1")); + assertEquals("Too old data is rejected", 20, nodeTimeSeries3.get(0).size()); + } + + @Test public void testGc() { String dataDir = "data/QuestMetricsDbGc"; IOUtils.recursiveDeleteDir(new File(dataDir)); @@ -102,4 +126,16 @@ public class QuestMetricsDbTest { return timeseries; } + private Collection<Pair<String, MetricSnapshot>> timeseriesAt(int countPerHost, Instant at, String ... hosts) { + Collection<Pair<String, MetricSnapshot>> timeseries = new ArrayList<>(); + for (int i = 1; i <= countPerHost; i++) { + for (String host : hosts) + timeseries.add(new Pair<>(host, new MetricSnapshot(at, + i * 0.1, + i * 0.2, + i * 0.4, + i % 100))); + } + return timeseries; + } } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisionTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisionTest.java index b871404aa9d..40d0f52dc37 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisionTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisionTest.java @@ -96,6 +96,7 @@ public class DynamicDockerProvisionTest { // Deploy new exclusive application ApplicationId application3 = ProvisioningTester.makeApplicationId(); + mockHostProvisioner(hostProvisioner, "large", 3, application3); prepareAndActivate(application3, clusterSpec("mycluster", true), 4, 1, resources); verify(hostProvisioner).provisionHosts(List.of(104, 105, 106, 107), resources, application3, Version.emptyVersion, HostSharing.exclusive); @@ -159,6 +160,28 @@ public class DynamicDockerProvisionTest { } @Test + public void retires_on_exclusivity_violation() { + ApplicationId application1 = ProvisioningTester.makeApplicationId(); + NodeResources resources = new NodeResources(1, 4, 10, 1); + + mockHostProvisioner(hostProvisioner, "large", 3, null); // Provision shared hosts + prepareAndActivate(application1, clusterSpec("mycluster"), 4, 1, resources); + Set<Node> initialNodes = tester.nodeRepository().list(application1).stream().collect(Collectors.toSet()); + assertEquals(4, initialNodes.size()); + + // Redeploy same application with exclusive=true + mockHostProvisioner(hostProvisioner, "large", 3, application1); + prepareAndActivate(application1, clusterSpec("mycluster", true), 4, 1, resources); + assertEquals(8, tester.nodeRepository().list(application1).size()); + assertEquals(initialNodes, tester.nodeRepository().list(application1).retired().stream().collect(Collectors.toSet())); + + // Redeploy without exclusive again is no-op + prepareAndActivate(application1, clusterSpec("mycluster"), 4, 1, resources); + assertEquals(8, tester.nodeRepository().list(application1).size()); + assertEquals(initialNodes, tester.nodeRepository().list(application1).retired().stream().collect(Collectors.toSet())); + } + + @Test public void node_indices_are_unique_even_when_a_node_is_left_in_reserved_state() { NodeResources resources = new NodeResources(10, 10, 10, 10); ApplicationId app = ProvisioningTester.makeApplicationId(); diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java index f22ddfb81c9..9c1ced856cc 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java @@ -15,6 +15,7 @@ import com.yahoo.config.provision.HostSpec; import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; import com.yahoo.config.provision.OutOfCapacityException; +import com.yahoo.config.provision.ParentHostUnavailableException; import com.yahoo.config.provision.RegionName; import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.Zone; @@ -24,9 +25,14 @@ import com.yahoo.vespa.hosted.provision.maintenance.ReservationExpirer; import com.yahoo.vespa.hosted.provision.maintenance.TestMetric; import com.yahoo.vespa.hosted.provision.node.Agent; import com.yahoo.vespa.hosted.provision.node.History; +import com.yahoo.vespa.hosted.provision.node.IP; +import com.yahoo.vespa.service.duper.ConfigServerApplication; +import com.yahoo.vespa.service.duper.ConfigServerHostApplication; +import com.yahoo.vespa.service.duper.InfraApplication; import org.junit.Test; import java.time.Duration; +import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; @@ -879,6 +885,37 @@ public class ProvisioningTest { } @Test + public void allocates_reserved_nodes_for_type_spec_deployment() { + ProvisioningTester tester = new ProvisioningTester.Builder().build(); + Function<InfraApplication, Collection<HostSpec>> prepareAndActivate = app -> tester.activate(app.getApplicationId(), + tester.prepare(app.getApplicationId(), app.getClusterSpecWithVersion(Version.fromString("1.2.3")), app.getCapacity())); + + // Add 2 config server hosts and 2 config servers + Flavor flavor = tester.nodeRepository().flavors().getFlavorOrThrow("default"); + List<Node> nodes = List.of( + Node.create("cfghost1", new IP.Config(Set.of("::1:0"), Set.of("::1:1")), "cfghost1", flavor, NodeType.confighost).build(), + Node.create("cfghost2", new IP.Config(Set.of("::2:0"), Set.of("::2:1")), "cfghost2", flavor, NodeType.confighost).ipConfig(Set.of("::2:0"), Set.of("::2:1")).build(), + Node.create("cfg1", new IP.Config(Set.of("::1:1"), Set.of()), "cfg1", flavor, NodeType.config).parentHostname("cfghost1").build(), + Node.create("cfg2", new IP.Config(Set.of("::2:1"), Set.of()), "cfg2", flavor, NodeType.config).parentHostname("cfghost2").build()); + tester.nodeRepository().setReady(tester.nodeRepository().addNodes(nodes, Agent.system), Agent.system, ProvisioningTest.class.getSimpleName()); + + InfraApplication cfgHostApp = new ConfigServerHostApplication(); + InfraApplication cfgApp = new ConfigServerApplication(); + + // Attempt to prepare & activate cfg, this should fail as cfg hosts are not active + try { + prepareAndActivate.apply(cfgApp); + } catch (ParentHostUnavailableException ignored) { } + assertEquals(2, tester.nodeRepository().list(cfgApp.getApplicationId()).state(Node.State.reserved).size()); + + prepareAndActivate.apply(cfgHostApp); + + // After activating cfg hosts, we can activate cfgs and all 4 should become active + prepareAndActivate.apply(cfgApp); + assertEquals(4, tester.nodeRepository().list().state(Node.State.active).size()); + } + + @Test public void cluster_spec_update_for_already_reserved_nodes() { ProvisioningTester tester = new ProvisioningTester.Builder().zone(new Zone(Environment.dev, RegionName.from("us-east"))).build(); ApplicationId application = ProvisioningTester.makeApplicationId(); diff --git a/searchcore/CMakeLists.txt b/searchcore/CMakeLists.txt index f98a3c87a2e..3e95c60f21b 100644 --- a/searchcore/CMakeLists.txt +++ b/searchcore/CMakeLists.txt @@ -144,7 +144,6 @@ vespa_define_module( src/tests/proton/server/health_adapter src/tests/proton/server/memory_flush_config_updater src/tests/proton/server/memoryflush - src/tests/proton/server/visibility_handler src/tests/proton/statusreport src/tests/proton/summaryengine src/tests/proton/verify_ranksetup diff --git a/searchcore/src/tests/proton/common/pendinglidtracker_test.cpp b/searchcore/src/tests/proton/common/pendinglidtracker_test.cpp index 3b42a399888..b276ed9e46d 100644 --- a/searchcore/src/tests/proton/common/pendinglidtracker_test.cpp +++ b/searchcore/src/tests/proton/common/pendinglidtracker_test.cpp @@ -86,24 +86,4 @@ TEST("test pendinglidtracker for needcommit") { EXPECT_EQUAL(ILidCommitState::State::COMPLETED, tracker.getState(LIDV_2_1_3)); } -TEST("test two phase pendinglidtracker for needcommit") { - TwoPhasePendingLidTracker tracker; - ILidCommitState::State incomplete = ILidCommitState::State::NEED_COMMIT; - verifyPhase1ProduceAndNeedCommit(tracker, incomplete); - EXPECT_EQUAL(incomplete, tracker.getState()); - EXPECT_EQUAL(incomplete, tracker.getState(LID_1)); - EXPECT_EQUAL(incomplete, tracker.getState(LIDV_2_1_3)); - EXPECT_EQUAL(ILidCommitState::State::COMPLETED, tracker.getState(LIDV_2_3)); - { - ILidCommitState::State waiting = ILidCommitState::State::WAITING; - auto snapshot = tracker.produceSnapshot(); - EXPECT_EQUAL(waiting, tracker.getState()); - EXPECT_EQUAL(waiting, tracker.getState(LID_1)); - EXPECT_EQUAL(waiting, tracker.getState(LIDV_2_1_3)); - } - EXPECT_EQUAL(ILidCommitState::State::COMPLETED, tracker.getState()); - EXPECT_EQUAL(ILidCommitState::State::COMPLETED, tracker.getState(LID_1)); - EXPECT_EQUAL(ILidCommitState::State::COMPLETED, tracker.getState(LIDV_2_1_3)); -} - TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchcore/src/tests/proton/document_iterator/document_iterator_test.cpp b/searchcore/src/tests/proton/document_iterator/document_iterator_test.cpp index 483e6719257..ee22e2668c6 100644 --- a/searchcore/src/tests/proton/document_iterator/document_iterator_test.cpp +++ b/searchcore/src/tests/proton/document_iterator/document_iterator_test.cpp @@ -274,20 +274,6 @@ struct PairDR : DocumentRetrieverBaseForTest { } }; -struct Committer : public ICommitable { - size_t _commitCount; - size_t _commitAndWaitCount; - Committer() : _commitCount(0), _commitAndWaitCount(0) { } - void commit() override { _commitCount++; } - void commitAndWait(ILidCommitState &) override { _commitAndWaitCount++; } - void commitAndWait(ILidCommitState & tracker, uint32_t ) override { - commitAndWait(tracker); - } - void commitAndWait(ILidCommitState & tracker, const std::vector<uint32_t> & ) override { - commitAndWait(tracker); - } -}; - size_t getSize() { return sizeof(DocEntry); } @@ -502,23 +488,57 @@ TEST("require that iterator ignoring maxbytes stops at the end, and does not aut TEST_DO(verifyIterateIgnoringStopSignal(itr)); } -void verifyReadConsistency(DocumentIterator & itr, Committer & committer) { - PendingLidTracker lidTracker; +void verifyReadConsistency(DocumentIterator & itr, ILidCommitState & lidCommitState) { IDocumentRetriever::SP retriever = doc("id:ns:document::1", Timestamp(2), bucket(5)); - auto commitAndWaitRetriever = std::make_shared<CommitAndWaitDocumentRetriever>(retriever, committer, lidTracker); + auto commitAndWaitRetriever = std::make_shared<CommitAndWaitDocumentRetriever>(retriever, lidCommitState); itr.add(commitAndWaitRetriever); IterateResult res = itr.iterate(largeNum); EXPECT_TRUE(res.isCompleted()); EXPECT_EQUAL(1u, res.getEntries().size()); TEST_DO(checkEntry(res, 0, Document(*DataType::DOCUMENT, DocumentId("id:ns:document::1")), Timestamp(2))); - EXPECT_EQUAL(0u, committer._commitCount); } +class ILidCommitStateProxy : public ILidCommitState { +public: + explicit ILidCommitStateProxy(ILidCommitState & lidState) + : _waitCompleteCount(0), + _lidState(lidState) + {} +private: + State waitState(State state, uint32_t lid) const override { + assert(state == State::COMPLETED); + _lidState.waitComplete(lid); + _waitCompleteCount++; + return state; + } + + State waitState(State state, const LidList &lids) const override { + assert(state == State::COMPLETED); + _lidState.waitComplete(lids); + _waitCompleteCount++; + return state; + } + + State waitState(State state) const override { + assert(state == State::COMPLETED); + _lidState.waitComplete(); + _waitCompleteCount++; + return state; + } + +public: + mutable size_t _waitCompleteCount; +private: + ILidCommitState & _lidState; +}; + void verifyStrongReadConsistency(DocumentIterator & itr) { - Committer committer; - TEST_DO(verifyReadConsistency(itr, committer)); - EXPECT_EQUAL(1u, committer._commitAndWaitCount); + PendingLidTracker lidTracker; + + ILidCommitStateProxy lidCommitState(lidTracker); + TEST_DO(verifyReadConsistency(itr, lidCommitState)); + EXPECT_EQUAL(1u, lidCommitState._waitCompleteCount); } TEST("require that default readconsistency does commit") { diff --git a/searchcore/src/tests/proton/documentdb/configurer/configurer_test.cpp b/searchcore/src/tests/proton/documentdb/configurer/configurer_test.cpp index b2903f00226..cb7fccea50f 100644 --- a/searchcore/src/tests/proton/documentdb/configurer/configurer_test.cpp +++ b/searchcore/src/tests/proton/documentdb/configurer/configurer_test.cpp @@ -7,7 +7,6 @@ #include <vespa/searchcore/proton/attribute/attributemanager.h> #include <vespa/searchcore/proton/attribute/imported_attributes_repo.h> #include <vespa/searchcore/proton/docsummary/summarymanager.h> -#include <vespa/searchcore/proton/documentmetastore/lid_reuse_delayer_config.h> #include <vespa/searchcore/proton/index/index_writer.h> #include <vespa/searchcore/proton/index/indexmanager.h> #include <vespa/searchcore/proton/reprocessing/attribute_reprocessing_initializer.h> @@ -51,7 +50,6 @@ using Configurer = SearchableDocSubDBConfigurer; using ConfigurerUP = std::unique_ptr<SearchableDocSubDBConfigurer>; using SummarySetup = SummaryManager::SummarySetup; using DocumenttypesConfigSP = proton::DocumentDBConfig::DocumenttypesConfigSP; -using LidReuseDelayerConfig = documentmetastore::LidReuseDelayerConfig; const vespalib::string BASE_DIR("baseDir"); const vespalib::string DOC_TYPE("invalid"); @@ -214,8 +212,7 @@ Fixture::initViewSet(ViewSet &views) views.searchView.get()->getDocumentMetaStore(), *views._gidToLidChangeHandler, views.repo, - views._writeService, - LidReuseDelayerConfig()), + views._writeService), SearchableFeedView::PersistentParams( views.serialNum, views.serialNum, @@ -260,7 +257,7 @@ struct MyFastAccessFeedView _dmsc = make_shared<DocumentMetaStoreContext>(std::make_shared<BucketDBOwner>()); std::shared_ptr<const DocumentTypeRepo> repo = createRepo(); StoreOnlyFeedView::Context storeOnlyCtx(summaryAdapter, schema, _dmsc, *_gidToLidChangeHandler, repo, - _writeService, LidReuseDelayerConfig()); + _writeService); StoreOnlyFeedView::PersistentParams params(1, 1, DocTypeName(DOC_TYPE), 0, SubDbType::NOTREADY); auto mgr = make_shared<AttributeManager>(BASE_DIR, "test.subdb", TuneFileAttributes(), _fileHeaderContext, _writeService.attributeFieldWriter(), _writeService.shared(), _hwInfo); @@ -446,11 +443,7 @@ TEST_F("require that we can reconfigure index searchable", Fixture) } { // verify feed view FeedViewComparer cmp(o.fv, n.fv); - cmp.expect_not_equal(); - cmp.expect_equal_index_adapter(); - cmp.expect_equal_attribute_writer(); - cmp.expect_equal_summary_adapter(); - cmp.expect_equal_schema(); + cmp.expect_equal(); } } @@ -604,11 +597,7 @@ TEST_F("require that we can reconfigure matchers", Fixture) } { // verify feed view FeedViewComparer cmp(o.fv, n.fv); - cmp.expect_not_equal(); - cmp.expect_equal_index_adapter(); - cmp.expect_equal_attribute_writer(); - cmp.expect_equal_summary_adapter(); - cmp.expect_equal_schema(); + cmp.expect_equal(); } } diff --git a/searchcore/src/tests/proton/documentdb/document_subdbs/document_subdbs_test.cpp b/searchcore/src/tests/proton/documentdb/document_subdbs/document_subdbs_test.cpp index 754cf4ea15d..06f0ba4109e 100644 --- a/searchcore/src/tests/proton/documentdb/document_subdbs/document_subdbs_test.cpp +++ b/searchcore/src/tests/proton/documentdb/document_subdbs/document_subdbs_test.cpp @@ -5,7 +5,6 @@ #include <vespa/searchcore/proton/attribute/imported_attributes_repo.h> #include <vespa/searchcore/proton/bucketdb/bucketdbhandler.h> #include <vespa/searchcore/proton/common/hw_info.h> -#include <vespa/searchcore/proton/common/icommitable.h> #include <vespa/searchcore/proton/initializer/task_runner.h> #include <vespa/searchcore/proton/metrics/attribute_metrics.h> #include <vespa/searchcore/proton/metrics/documentdb_tagged_metrics.h> @@ -253,20 +252,6 @@ struct TwoAttrSchema : public OneAttrSchema } }; -struct Committer : public ICommitable { - size_t _commitCount; - size_t _commitAndWaitCount; - Committer() : _commitCount(0), _commitAndWaitCount(0) { } - void commit() override { _commitCount++; } - void commitAndWait(ILidCommitState & ) override { _commitAndWaitCount++; } - void commitAndWait(ILidCommitState & tracker, uint32_t ) override { - commitAndWait(tracker); - } - void commitAndWait(ILidCommitState & tracker, const std::vector<uint32_t> & ) override { - commitAndWait(tracker); - } -}; - struct MyConfigSnapshot { typedef std::unique_ptr<MyConfigSnapshot> UP; @@ -281,7 +266,7 @@ struct MyConfigSnapshot _bootstrap() { auto documenttypesConfig = std::make_shared<DocumenttypesConfig>(_builder.getDocumenttypesConfig()); - TuneFileDocumentDB::SP tuneFileDocumentDB(new TuneFileDocumentDB()); + auto tuneFileDocumentDB = std::make_shared<TuneFileDocumentDB>(); _bootstrap = std::make_shared<BootstrapConfig>(1, documenttypesConfig, _builder.getDocumentTypeRepo(), diff --git a/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp b/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp index 9bb8865707d..b9683c49c11 100644 --- a/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp +++ b/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp @@ -4,7 +4,6 @@ #include <vespa/searchcore/proton/attribute/ifieldupdatecallback.h> #include <vespa/searchcore/proton/test/bucketfactory.h> #include <vespa/searchcore/proton/common/feedtoken.h> -#include <vespa/searchcore/proton/documentmetastore/lid_reuse_delayer_config.h> #include <vespa/searchcore/proton/index/i_index_writer.h> #include <vespa/searchcore/proton/server/executorthreadingservice.h> #include <vespa/searchcore/proton/server/isummaryadapter.h> @@ -34,7 +33,6 @@ using document::DocumentId; using document::DocumentUpdate; using proton::matching::SessionManager; using proton::test::MockGidToLidChangeHandler; -using proton::documentmetastore::LidReuseDelayerConfig; using search::AttributeVector; using search::CacheStats; using search::DocumentMetaData; @@ -505,10 +503,9 @@ struct FixtureBase vespalib::ThreadStackExecutor _sharedExecutor; ExecutorThreadingService _writeServiceReal; test::ThreadingServiceObserver _writeService; - vespalib::duration _visibilityDelay; SerialNum serial; std::shared_ptr<MyGidToLidChangeHandler> _gidToLidChangeHandler; - FixtureBase(vespalib::duration visibilityDelay); + FixtureBase(); virtual ~FixtureBase(); @@ -678,7 +675,7 @@ struct FixtureBase }; -FixtureBase::FixtureBase(vespalib::duration visibilityDelay) +FixtureBase::FixtureBase() : _tracer(), sc(), iw(std::make_shared<MyIndexWriter>(_tracer)), @@ -694,7 +691,6 @@ FixtureBase::FixtureBase(vespalib::duration visibilityDelay) _sharedExecutor(1, 0x10000), _writeServiceReal(_sharedExecutor), _writeService(_writeServiceReal), - _visibilityDelay(visibilityDelay), serial(0), _gidToLidChangeHandler(std::make_shared<MyGidToLidChangeHandler>()) { @@ -710,44 +706,51 @@ FixtureBase::populateBeforeCompactLidSpace() { putAndWait(makeDummyDocs(0, 2, 1000)); removeAndWait(makeDummyDocs(1, 1, 2000)); + forceCommitAndWait(); } struct SearchableFeedViewFixture : public FixtureBase { SearchableFeedView fv; - SearchableFeedViewFixture(vespalib::duration visibilityDelay = 0ms) : - FixtureBase(visibilityDelay), + SearchableFeedViewFixture() : + FixtureBase(), fv(StoreOnlyFeedView::Context(sa, sc._schema, _dmsc, *_gidToLidChangeHandler, sc.getRepo(), - _writeService, - LidReuseDelayerConfig(_visibilityDelay, true)), + _writeService), pc.getParams(), FastAccessFeedView::Context(aw, _docIdLimit), SearchableFeedView::Context(iw)) { } + ~SearchableFeedViewFixture() override + { + forceCommitAndWait(); + } IFeedView &getFeedView() override { return fv; } }; struct FastAccessFeedViewFixture : public FixtureBase { FastAccessFeedView fv; - FastAccessFeedViewFixture(vespalib::duration visibilityDelay = vespalib::duration::zero()) : - FixtureBase(visibilityDelay), + FastAccessFeedViewFixture() : + FixtureBase(), fv(StoreOnlyFeedView::Context(sa, sc._schema, _dmsc, *_gidToLidChangeHandler, sc.getRepo(), - _writeService, - LidReuseDelayerConfig(_visibilityDelay, false)), + _writeService), pc.getParams(), FastAccessFeedView::Context(aw, _docIdLimit)) { } + ~FastAccessFeedViewFixture() override + { + forceCommitAndWait(); + } IFeedView &getFeedView() override { return fv; } }; @@ -907,12 +910,14 @@ TEST_F("require that remove() calls removeComplete() via delayed thread service" { EXPECT_TRUE(assertThreadObserver(0, 0, 0, f.writeServiceObserver())); f.putAndWait(f.doc1(10)); + f.forceCommitAndWait(); // put index fields handled in index thread - EXPECT_TRUE(assertThreadObserver(1, 1, 1, f.writeServiceObserver())); + EXPECT_TRUE(assertThreadObserver(2, 2, 2, f.writeServiceObserver())); f.removeAndWait(f.doc1(20)); + f.forceCommitAndWait(); // remove index fields handled in index thread // delayed remove complete handled in same index thread, then master thread - EXPECT_TRUE(assertThreadObserver(3, 2, 2, f.writeServiceObserver())); + EXPECT_TRUE(assertThreadObserver(5, 4, 4, f.writeServiceObserver())); EXPECT_EQUAL(1u, f.metaStoreObserver()._removeCompleteCnt); EXPECT_EQUAL(1u, f.metaStoreObserver()._removeCompleteLid); } @@ -995,18 +1000,25 @@ TEST_F("require that removes are not remembered", SearchableFeedViewFixture) docs.push_back(f.doc("id:test:searchdocument:n=2:2", 14)); f.putAndWait(docs); + f.forceCommitAndWait(); f.removeAndWait(docs[0]); + f.forceCommitAndWait(); f.removeAndWait(docs[3]); + f.forceCommitAndWait(); assertPostConditionAfterRemoves(docs, f); // try to remove again : should have little effect f.removeAndWait(docs[0]); + f.forceCommitAndWait(); f.removeAndWait(docs[3]); + f.forceCommitAndWait(); assertPostConditionAfterRemoves(docs, f); // re-add docs f.putAndWait(docs[3]); + f.forceCommitAndWait(); f.putAndWait(docs[0]); + f.forceCommitAndWait(); EXPECT_EQUAL(5u, f.getMetaStore().getNumUsedLids()); EXPECT_TRUE(f.getMetaData(docs[0]).valid()); EXPECT_TRUE(f.getMetaData(docs[1]).valid()); @@ -1030,7 +1042,9 @@ TEST_F("require that removes are not remembered", SearchableFeedViewFixture) EXPECT_EQUAL(5u, f.msa._store._docs.size()); f.removeAndWait(docs[0]); + f.forceCommitAndWait(); f.removeAndWait(docs[3]); + f.forceCommitAndWait(); EXPECT_EQUAL(3u, f.msa._store._docs.size()); } @@ -1047,11 +1061,13 @@ void putDocumentAndUpdate(Fixture &f, const vespalib::string &fieldName) { DocumentContext dc1 = f.doc1(); f.putAndWait(dc1); + f.forceCommitAndWait(); EXPECT_EQUAL(1u, f.msa._store._lastSyncToken); DocumentContext dc2("id:ns:searchdocument::1", 20, f.getBuilder()); dc2.addFieldUpdate(f.getBuilder(), fieldName); f.updateAndWait(dc2); + f.forceCommitAndWait(); } template <typename Fixture> @@ -1127,11 +1143,11 @@ TEST_F("require that compactLidSpace() propagates to document meta store and doc SearchableFeedViewFixture) { f.populateBeforeCompactLidSpace(); - EXPECT_TRUE(assertThreadObserver(4, 3, 3, f.writeServiceObserver())); + EXPECT_TRUE(assertThreadObserver(5, 4, 4, f.writeServiceObserver())); f.compactLidSpaceAndWait(2); // performIndexForceCommit in index thread, then completion callback // in master thread. - EXPECT_TRUE(assertThreadObserver(6, 5, 5, f.writeServiceObserver())); + EXPECT_TRUE(assertThreadObserver(7, 6, 6, f.writeServiceObserver())); EXPECT_EQUAL(2u, f.metaStoreObserver()._compactLidSpaceLidLimit); EXPECT_EQUAL(2u, f.getDocumentStore()._compactLidSpaceLidLimit); EXPECT_EQUAL(1u, f.metaStoreObserver()._holdUnblockShrinkLidSpaceCnt); @@ -1144,12 +1160,12 @@ TEST_F("require that compactLidSpace() doesn't propagate to " SearchableFeedViewFixture) { f.populateBeforeCompactLidSpace(); - EXPECT_TRUE(assertThreadObserver(4, 3, 3, f.writeServiceObserver())); + EXPECT_TRUE(assertThreadObserver(5, 4, 4, f.writeServiceObserver())); CompactLidSpaceOperation op(0, 2); op.setSerialNum(0); f.runInMaster([&] () { f.fv.handleCompactLidSpace(op); }); // Delayed holdUnblockShrinkLidSpace() in index thread, then master thread - EXPECT_TRUE(assertThreadObserver(5, 4, 3, f.writeServiceObserver())); + EXPECT_TRUE(assertThreadObserver(6, 5, 4, f.writeServiceObserver())); EXPECT_EQUAL(0u, f.metaStoreObserver()._compactLidSpaceLidLimit); EXPECT_EQUAL(0u, f.getDocumentStore()._compactLidSpaceLidLimit); EXPECT_EQUAL(0u, f.metaStoreObserver()._holdUnblockShrinkLidSpaceCnt); @@ -1171,40 +1187,14 @@ TEST_F("require that compactLidSpace() propagates to index writer", EXPECT_EQUAL(2u, f.miw._wantedLidLimit); } -const vespalib::duration LONG_DELAY = 60s; -const vespalib::duration SHORT_DELAY = 500ms; - -TEST_F("require that commit is not called when inside a commit interval", - SearchableFeedViewFixture(LONG_DELAY)) -{ - DocumentContext dc = f.doc1(); - f.putAndWait(dc); - EXPECT_EQUAL(0u, f.miw._commitCount); - EXPECT_EQUAL(0u, f.maw._commitCount); - EXPECT_EQUAL(0u, f._docIdLimit.get()); - f.removeAndWait(dc); - EXPECT_EQUAL(0u, f.miw._commitCount); - EXPECT_EQUAL(0u, f.maw._commitCount); - EXPECT_EQUAL(0u, f._docIdLimit.get()); - f.assertTrace("put(adapter=attribute,serialNum=1,lid=1)," - "put(adapter=index,serialNum=1,lid=1)," - "ack(Result(0, ))," - "remove(adapter=attribute,serialNum=2,lid=1)," - "remove(adapter=index,serialNum=2,lid=1)," - "ack(Result(0, ))"); - f.forceCommitAndWait(); -} - TEST_F("require that commit is not implicitly called", - SearchableFeedViewFixture(SHORT_DELAY)) + SearchableFeedViewFixture) { - std::this_thread::sleep_for(SHORT_DELAY + 100ms); DocumentContext dc = f.doc1(); f.putAndWait(dc); EXPECT_EQUAL(0u, f.miw._commitCount); EXPECT_EQUAL(0u, f.maw._commitCount); EXPECT_EQUAL(0u, f._docIdLimit.get()); - std::this_thread::sleep_for(SHORT_DELAY + 100ms); f.removeAndWait(dc); EXPECT_EQUAL(0u, f.miw._commitCount); EXPECT_EQUAL(0u, f.maw._commitCount); @@ -1219,7 +1209,7 @@ TEST_F("require that commit is not implicitly called", } TEST_F("require that forceCommit updates docid limit", - SearchableFeedViewFixture(LONG_DELAY)) + SearchableFeedViewFixture) { DocumentContext dc = f.doc1(); f.putAndWait(dc); @@ -1237,7 +1227,7 @@ TEST_F("require that forceCommit updates docid limit", "commit(adapter=index,serialNum=1)"); } -TEST_F("require that forceCommit updates docid limit during shrink", SearchableFeedViewFixture(LONG_DELAY)) +TEST_F("require that forceCommit updates docid limit during shrink", SearchableFeedViewFixture) { f.putAndWait(f.makeDummyDocs(0, 3, 1000)); EXPECT_EQUAL(0u, f._docIdLimit.get()); @@ -1262,13 +1252,17 @@ TEST_F("require that move() notifies gid to lid change handler", SearchableFeedV DocumentContext dc1 = f.doc("id::searchdocument::1", 10); DocumentContext dc2 = f.doc("id::searchdocument::2", 20); f.putAndWait(dc1); + f.forceCommitAndWait(); TEST_DO(f.assertChangeHandler(dc1.gid(), 1u, 1u)); f.putAndWait(dc2); + f.forceCommitAndWait(); TEST_DO(f.assertChangeHandler(dc2.gid(), 2u, 2u)); DocumentContext dc3 = f.doc("id::searchdocument::1", 30); f.removeAndWait(dc3); + f.forceCommitAndWait(); TEST_DO(f.assertChangeHandler(dc3.gid(), 0u, 3u)); f.moveAndWait(dc2, 2, 1); + f.forceCommitAndWait(); TEST_DO(f.assertChangeHandler(dc2.gid(), 1u, 4u)); } diff --git a/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp b/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp index 09c4c17220e..cd6b23e5e26 100644 --- a/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp +++ b/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp @@ -9,7 +9,6 @@ #include <vespa/searchcore/proton/attribute/i_attribute_manager.h> #include <vespa/searchcore/proton/bucketdb/bucket_create_notifier.h> #include <vespa/searchcore/proton/common/doctypename.h> -#include <vespa/searchcore/proton/common/feedtoken.h> #include <vespa/searchcore/proton/common/transient_memory_usage_provider.h> #include <vespa/searchcore/proton/documentmetastore/operation_listener.h> #include <vespa/searchcore/proton/feedoperation/moveoperation.h> @@ -354,7 +353,7 @@ struct MockLidSpaceCompactionHandler : public ILidSpaceCompactionHandler }; -class MaintenanceControllerFixture : public ICommitable +class MaintenanceControllerFixture { public: MyExecutor _executor; @@ -385,17 +384,9 @@ public: MaintenanceControllerFixture(); - ~MaintenanceControllerFixture() override; + ~MaintenanceControllerFixture(); void syncSubDBs(); - void commit() override { } - void commitAndWait(ILidCommitState & ) override { } - void commitAndWait(ILidCommitState & tracker, uint32_t ) override { - commitAndWait(tracker); - } - void commitAndWait(ILidCommitState & tracker, const std::vector<uint32_t> & ) override { - commitAndWait(tracker); - } void performSyncSubDBs(); void notifyClusterStateChanged(); void performNotifyClusterStateChanged(); diff --git a/searchcore/src/tests/proton/documentdb/storeonlyfeedview/storeonlyfeedview_test.cpp b/searchcore/src/tests/proton/documentdb/storeonlyfeedview/storeonlyfeedview_test.cpp index 3a75f8cd494..c162d7dcd28 100644 --- a/searchcore/src/tests/proton/documentdb/storeonlyfeedview/storeonlyfeedview_test.cpp +++ b/searchcore/src/tests/proton/documentdb/storeonlyfeedview/storeonlyfeedview_test.cpp @@ -3,7 +3,6 @@ #include <vespa/document/base/documentid.h> #include <vespa/document/datatype/datatype.h> #include <vespa/searchcommon/common/schema.h> -#include <vespa/searchcore/proton/documentmetastore/lid_reuse_delayer_config.h> #include <vespa/searchcore/proton/server/executorthreadingservice.h> #include <vespa/searchcore/proton/server/putdonecontext.h> #include <vespa/searchcore/proton/server/removedonecontext.h> @@ -86,7 +85,6 @@ struct MyMinimalFeedView : public MyMinimalFeedViewBase, public StoreOnlyFeedVie MyMinimalFeedView(const ISummaryAdapter::SP &summaryAdapter, const DocumentMetaStore::SP &metaStore, searchcorespi::index::IThreadingService &writeService, - documentmetastore::LidReuseDelayerConfig &lidReuseDelayerConfig, const PersistentParams ¶ms, int &outstandingMoveOps_) : MyMinimalFeedViewBase(), @@ -95,8 +93,7 @@ struct MyMinimalFeedView : public MyMinimalFeedViewBase, public StoreOnlyFeedVie std::make_shared<DocumentMetaStoreContext>(metaStore), *gidToLidChangeHandler, myGetDocumentTypeRepo(), - writeService, - lidReuseDelayerConfig), + writeService), params), removeMultiAttributesCount(0), removeMultiIndexFieldsCount(0), @@ -134,10 +131,9 @@ struct MoveOperationFeedView : public MyMinimalFeedView { MoveOperationFeedView(const ISummaryAdapter::SP &summaryAdapter, const DocumentMetaStore::SP &metaStore, searchcorespi::index::IThreadingService &writeService, - documentmetastore::LidReuseDelayerConfig &lidReuseDelayerConfig, const PersistentParams ¶ms, int &outstandingMoveOps_) : - MyMinimalFeedView(summaryAdapter, metaStore, writeService, lidReuseDelayerConfig, + MyMinimalFeedView(summaryAdapter, metaStore, writeService, params, outstandingMoveOps_), putAttributesCount(0), putIndexFieldsCount(0), @@ -191,8 +187,8 @@ struct FixtureBase { DocumentMetaStore::SP metaStore; vespalib::ThreadStackExecutor sharedExecutor; ExecutorThreadingService writeService; - documentmetastore::LidReuseDelayerConfig lidReuseDelayerConfig; typename FeedViewType::UP feedview; + SerialNum serial_num; explicit FixtureBase(SubDbType subDbType = SubDbType::READY) : removeCount(0), @@ -206,18 +202,18 @@ struct FixtureBase { subDbType)), sharedExecutor(1, 0x10000), writeService(sharedExecutor), - lidReuseDelayerConfig(), - feedview() + feedview(), + serial_num(2u) { StoreOnlyFeedView::PersistentParams params(0, 0, DocTypeName("foo"), subdb_id, subDbType); metaStore->constructFreeList(); ISummaryAdapter::SP adapter = std::make_unique<MySummaryAdapter>(removeCount, putCount, heartbeatCount); - feedview = std::make_unique<FeedViewType>(adapter, metaStore, writeService, lidReuseDelayerConfig, + feedview = std::make_unique<FeedViewType>(adapter, metaStore, writeService, params, outstandingMoveOps); } ~FixtureBase() { - writeService.sync(); + this->force_commit(); } void addSingleDocToMetaStore(uint32_t expected_lid) { @@ -243,6 +239,10 @@ struct FixtureBase { test::runInMaster(writeService, func); } + void force_commit() { + runInMaster([this] () { static_cast<IFeedView&>(*feedview).forceCommit(serial_num); }); + writeService.sync(); + } }; using Fixture = FixtureBase<MyMinimalFeedView>; diff --git a/searchcore/src/tests/proton/documentmetastore/lidreusedelayer/lidreusedelayer_test.cpp b/searchcore/src/tests/proton/documentmetastore/lidreusedelayer/lidreusedelayer_test.cpp index 453f7eb638e..6e4fe34a3c9 100644 --- a/searchcore/src/tests/proton/documentmetastore/lidreusedelayer/lidreusedelayer_test.cpp +++ b/searchcore/src/tests/proton/documentmetastore/lidreusedelayer/lidreusedelayer_test.cpp @@ -128,7 +128,6 @@ class Fixture { public: using LidReuseDelayer = documentmetastore::LidReuseDelayer; - using LidReuseDelayerConfig = documentmetastore::LidReuseDelayerConfig; vespalib::ThreadStackExecutor _sharedExecutor; ExecutorThreadingService _writeServiceReal; test::ThreadingServiceObserver _writeService; @@ -140,7 +139,7 @@ public: _writeServiceReal(_sharedExecutor), _writeService(_writeServiceReal), _store(), - _lidReuseDelayer(std::make_unique<LidReuseDelayer>(_writeService, _store, LidReuseDelayerConfig())) + _lidReuseDelayer(std::make_unique<LidReuseDelayer>(_writeService, _store)) { } @@ -195,15 +194,6 @@ public: return res; } - void - configureLidReuseDelayer(bool immediateCommit, bool hasIndexedOrAttributeFields) { - runInMaster([&] () { - _lidReuseDelayer = std::make_unique<LidReuseDelayer>(_writeService, _store, - LidReuseDelayerConfig(immediateCommit ? vespalib::duration::zero() : 1ms, - hasIndexedOrAttributeFields)); - } ); - } - void commit() { runInMaster([&] () { cycleLids(_lidReuseDelayer->getReuseLids()); }); } @@ -230,88 +220,43 @@ public: TEST_F("require that nothing happens before free list is active", Fixture) { - f.configureLidReuseDelayer(true, true); EXPECT_FALSE(f.delayReuse(4)); EXPECT_FALSE(f.delayReuse({ 5, 6})); EXPECT_TRUE(f._store.assertWork(0, 0, 0)); - EXPECT_TRUE(assertThreadObserver(3, 0, 0, f._writeService)); -} - - -TEST_F("require that single lid is delayed", Fixture) -{ - f._store._freeListActive = true; - f.configureLidReuseDelayer(true, true); - EXPECT_TRUE(f.delayReuse(4)); - f.scheduleDelayReuseLid(4); - EXPECT_TRUE(f._store.assertWork(1, 0, 1)); - EXPECT_TRUE(assertThreadObserver(4, 1, 0, f._writeService)); -} - - -TEST_F("require that lid vector is delayed", Fixture) -{ - f._store._freeListActive = true; - f.configureLidReuseDelayer(true, true); - EXPECT_TRUE(f.delayReuse({ 5, 6, 7})); - f.scheduleDelayReuseLids({ 5, 6, 7}); - EXPECT_TRUE(f._store.assertWork(0, 1, 3)); - EXPECT_TRUE(assertThreadObserver(4, 1, 0, f._writeService)); + EXPECT_TRUE(assertThreadObserver(2, 0, 0, f._writeService)); } TEST_F("require that reuse can be batched", Fixture) { f._store._freeListActive = true; - f.configureLidReuseDelayer(false, true); EXPECT_FALSE(f.delayReuse(4)); EXPECT_FALSE(f.delayReuse({ 5, 6, 7})); EXPECT_TRUE(f._store.assertWork(0, 0, 0)); - EXPECT_TRUE(assertThreadObserver(3, 0, 0, f._writeService)); + EXPECT_TRUE(assertThreadObserver(2, 0, 0, f._writeService)); f.commit(); EXPECT_TRUE(f._store.assertWork(0, 1, 4)); - EXPECT_TRUE(assertThreadObserver(5, 1, 0, f._writeService)); + EXPECT_TRUE(assertThreadObserver(4, 1, 0, f._writeService)); EXPECT_FALSE(f.delayReuse(8)); EXPECT_FALSE(f.delayReuse({ 9, 10})); EXPECT_TRUE(f._store.assertWork(0, 1, 4)); - EXPECT_TRUE(assertThreadObserver(7, 1, 0, f._writeService)); + EXPECT_TRUE(assertThreadObserver(6, 1, 0, f._writeService)); } TEST_F("require that single element array is optimized", Fixture) { f._store._freeListActive = true; - f.configureLidReuseDelayer(false, true); EXPECT_FALSE(f.delayReuse({ 4})); EXPECT_TRUE(f._store.assertWork(0, 0, 0)); - EXPECT_TRUE(assertThreadObserver(2, 0, 0, f._writeService)); + EXPECT_TRUE(assertThreadObserver(1, 0, 0, f._writeService)); f.commit(); - f.configureLidReuseDelayer(true, true); EXPECT_TRUE(f._store.assertWork(1, 0, 1)); - EXPECT_TRUE(assertThreadObserver(5, 1, 0, f._writeService)); - EXPECT_TRUE(f.delayReuse({ 8})); - f.scheduleDelayReuseLids({ 8}); - EXPECT_TRUE(f._store.assertWork(2, 0, 2)); - EXPECT_TRUE(assertThreadObserver(8, 2, 0, f._writeService)); + EXPECT_TRUE(assertThreadObserver(3, 1, 0, f._writeService)); } - -TEST_F("require that lids are reused faster with no indexed fields", Fixture) -{ - f._store._freeListActive = true; - f.configureLidReuseDelayer(true, false); - EXPECT_FALSE(f.delayReuse(4)); - EXPECT_TRUE(f._store.assertWork(1, 0, 1)); - EXPECT_TRUE(assertThreadObserver(2, 0, 0, f._writeService)); - EXPECT_FALSE(f.delayReuse({ 5, 6, 7})); - EXPECT_TRUE(f._store.assertWork(1, 1, 4)); - EXPECT_TRUE(assertThreadObserver(3, 0, 0, f._writeService)); } -} - - - TEST_MAIN() { TEST_RUN_ALL(); diff --git a/searchcore/src/tests/proton/server/visibility_handler/.gitignore b/searchcore/src/tests/proton/server/visibility_handler/.gitignore deleted file mode 100644 index 3666e0c37c3..00000000000 --- a/searchcore/src/tests/proton/server/visibility_handler/.gitignore +++ /dev/null @@ -1 +0,0 @@ -searchcore_visibility_handler_test_app diff --git a/searchcore/src/tests/proton/server/visibility_handler/CMakeLists.txt b/searchcore/src/tests/proton/server/visibility_handler/CMakeLists.txt deleted file mode 100644 index cb79d0fae8a..00000000000 --- a/searchcore/src/tests/proton/server/visibility_handler/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -vespa_add_executable(searchcore_visibility_handler_test_app TEST - SOURCES - visibility_handler_test.cpp - DEPENDS - searchcore_server - searchcore_test -) -vespa_add_test(NAME searchcore_visibility_handler_test_app COMMAND searchcore_visibility_handler_test_app) diff --git a/searchcore/src/tests/proton/server/visibility_handler/visibility_handler_test.cpp b/searchcore/src/tests/proton/server/visibility_handler/visibility_handler_test.cpp deleted file mode 100644 index 2048ecb5458..00000000000 --- a/searchcore/src/tests/proton/server/visibility_handler/visibility_handler_test.cpp +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#include <vespa/vespalib/testkit/testapp.h> -#include <vespa/searchcore/proton/server/visibilityhandler.h> -#include <vespa/searchcore/proton/test/dummy_feed_view.h> -#include <vespa/searchcore/proton/test/threading_service_observer.h> -#include <vespa/searchcore/proton/server/executorthreadingservice.h> -#include <vespa/searchcore/proton/common/pendinglidtracker.h> -#include <vespa/vespalib/util/lambdatask.h> - -#include <vespa/log/log.h> -LOG_SETUP("visibility_handler_test"); - -using search::SerialNum; -using proton::IGetSerialNum; -using proton::test::DummyFeedView; -using proton::ExecutorThreadingService; -using proton::test::ThreadingServiceObserver; -using proton::IFeedView; -using proton::VisibilityHandler; -using vespalib::makeLambdaTask; - -namespace { - -class MyGetSerialNum : public IGetSerialNum -{ - SerialNum _serialNum; -public: - MyGetSerialNum() - : _serialNum(0u) - {} - SerialNum getSerialNum() const override { return _serialNum; } - void setSerialNum(SerialNum serialNum) { _serialNum = serialNum; } -}; - - - -class MyFeedView : public DummyFeedView -{ - uint32_t _forceCommitCount; - SerialNum _committedSerialNum; -public: - std::unique_ptr<proton::PendingLidTrackerBase> _tracker; - - - MyFeedView() - : _forceCommitCount(0u), - _committedSerialNum(0u) - {} - - void setTracker(vespalib::duration visibilityDelay) { - if (visibilityDelay == vespalib::duration::zero()) { - _tracker = std::make_unique<proton::PendingLidTracker>(); - } else { - _tracker = std::make_unique<proton::TwoPhasePendingLidTracker>(); - } - } - - void forceCommit(SerialNum serialNum, DoneCallback) override - { - EXPECT_TRUE(serialNum >= _committedSerialNum); - _committedSerialNum = serialNum; - ++_forceCommitCount; - _tracker->produceSnapshot(); - } - - uint32_t getForceCommitCount() const { return _forceCommitCount; } - SerialNum getCommittedSerialNum() const { return _committedSerialNum; } -}; - - -class Fixture -{ -public: - MyGetSerialNum _getSerialNum; - vespalib::ThreadStackExecutor _sharedExecutor; - ExecutorThreadingService _writeServiceReal; - ThreadingServiceObserver _writeService; - std::shared_ptr<MyFeedView> _feedViewReal; - vespalib::VarHolder<IFeedView::SP> _feedView; - VisibilityHandler _visibilityHandler; - - - Fixture() - : _getSerialNum(), - _sharedExecutor(1, 0x10000), - _writeServiceReal(_sharedExecutor), - _writeService(_writeServiceReal), - _feedViewReal(std::make_shared<MyFeedView>()), - _feedView(_feedViewReal), - _visibilityHandler(_getSerialNum, _writeService, _feedView) - {} - - void - checkCommitPostCondition(uint32_t expForceCommitCount, - SerialNum expCommittedSerialNum, - uint32_t expMasterExecuteCnt) - { - EXPECT_EQUAL(expForceCommitCount, _feedViewReal->getForceCommitCount()); - EXPECT_EQUAL(expCommittedSerialNum, - _feedViewReal->getCommittedSerialNum()); - EXPECT_EQUAL(expMasterExecuteCnt, - _writeService.masterObserver().getExecuteCnt()); - } - - void - testCommit(vespalib::duration visibilityDelay, bool internal, - uint32_t expForceCommitCount, SerialNum expCommittedSerialNum, - uint32_t expMasterExecuteCnt, - SerialNum currSerialNum = 10u) - { - _feedViewReal->setTracker(visibilityDelay); - _getSerialNum.setSerialNum(currSerialNum); - _visibilityHandler.setVisibilityDelay(visibilityDelay); - if (internal) { - VisibilityHandler *visibilityHandler = &_visibilityHandler; - auto task = makeLambdaTask([=]() { visibilityHandler->commit(); }); - _writeService.master().execute(std::move(task)); - } else { - _visibilityHandler.commit(); - } - _writeService.master().sync(); - checkCommitPostCondition(expForceCommitCount, - expCommittedSerialNum, - expMasterExecuteCnt); - } - - proton::PendingLidTracker::Token - createToken(proton::PendingLidTrackerBase & tracker, SerialNum serialNum, uint32_t lid) { - if (serialNum == 0) { - return proton::PendingLidTracker::Token(); - } else { - return tracker.produce(lid);; - } - } - - void - testCommitAndWait(vespalib::duration visibilityDelay, bool internal, - uint32_t expForceCommitCount, - SerialNum expCommittedSerialNum, - uint32_t expMasterExecuteCnt, - SerialNum currSerialNum = 10u) - { - _feedViewReal->setTracker(visibilityDelay); - _getSerialNum.setSerialNum(currSerialNum); - _visibilityHandler.setVisibilityDelay(visibilityDelay); - constexpr uint32_t MY_LID=13; - proton::PendingLidTrackerBase * lidTracker = _feedViewReal->_tracker.get(); - { - proton::PendingLidTracker::Token token = createToken(*lidTracker, currSerialNum, MY_LID); - } - if (internal) { - VisibilityHandler *visibilityHandler = &_visibilityHandler; - auto task = makeLambdaTask([=]() { visibilityHandler->commitAndWait(*lidTracker, MY_LID); }); - _writeService.master().execute(std::move(task)); - _writeService.master().sync(); - } else { - _visibilityHandler.commitAndWait(*lidTracker, MY_LID); - } - checkCommitPostCondition(expForceCommitCount, - expCommittedSerialNum, - expMasterExecuteCnt); - } -}; - -} - -TEST_F("Check external commit with zero visibility delay", Fixture) -{ - f.testCommit(0s, false, 0u, 0u, 0u); -} - -TEST_F("Check external commit with nonzero visibility delay", Fixture) -{ - f.testCommit(1s, false, 1u, 10u, 1u); -} - -TEST_F("Check external commit with nonzero visibility delay and no new feed operation", Fixture) -{ - f.testCommit(1s, false, 1u, 0u, 1u, 0u); -} - -TEST_F("Check internal commit with zero visibility delay", Fixture) -{ - f.testCommit(0s, true, 0u, 0u, 1u); -} - -TEST_F("Check internal commit with nonzero visibility delay", Fixture) -{ - f.testCommit(1s, true, 1u, 10u, 1u); -} - -TEST_F("Check internal commit with nonzero visibility delay and no new feed operation", Fixture) -{ - f.testCommit(1s, true, 1u, 0u, 1u, 0u); -} - -TEST_F("Check external commitAndWait with zero visibility delay", Fixture) -{ - f.testCommitAndWait(0s, false, 0u, 0u, 0u); -} - -TEST_F("Check external commitAndWait with nonzero visibility delay", Fixture) -{ - f.testCommitAndWait(1s, false, 1u, 10u, 1u); -} - -TEST_F("Check external commitAndWait with nonzero visibility delay and no new feed operation", Fixture) -{ - f.testCommitAndWait(1s, false, 0u, 0u, 0u, 0u); -} - -TEST_F("Check internal commitAndWait with zero visibility delay", Fixture) -{ - f.testCommitAndWait(0s, true, 0u, 0u, 1u); -} - -TEST_F("Check internal commitAndWait with nonzero visibility delay", Fixture) -{ - f.testCommitAndWait(1s, true, 1u, 10u, 1u); -} - -TEST_F("Check internal commitAndWait with nonzero visibility delay and no new feed operation", Fixture) -{ - f.testCommitAndWait(1s, true, 0u, 0u, 1u, 0u); -} - -TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchcore/src/vespa/searchcore/proton/common/icommitable.h b/searchcore/src/vespa/searchcore/proton/common/icommitable.h deleted file mode 100644 index 55762a69862..00000000000 --- a/searchcore/src/vespa/searchcore/proton/common/icommitable.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#pragma once - -#include <vector> -namespace proton { - -class ILidCommitState; - -/** - * Interface for anyone that needs to commit. - **/ -class ICommitable { -public: - virtual void commit() = 0; - virtual void commitAndWait(ILidCommitState & unCommittedLidTracker) = 0; - virtual void commitAndWait(ILidCommitState &uncommittedLidTracker, uint32_t lid) = 0; - virtual void commitAndWait(ILidCommitState &uncommittedLidTracker, const std::vector<uint32_t> & lid) = 0; -protected: - virtual ~ICommitable() = default; -}; - -} diff --git a/searchcore/src/vespa/searchcore/proton/common/pendinglidtracker.cpp b/searchcore/src/vespa/searchcore/proton/common/pendinglidtracker.cpp index dd6ca70248b..24ddacafaa8 100644 --- a/searchcore/src/vespa/searchcore/proton/common/pendinglidtracker.cpp +++ b/searchcore/src/vespa/searchcore/proton/common/pendinglidtracker.cpp @@ -1,8 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "pendinglidtracker.h" -#include <vespa/vespalib/stllike/hash_map.hpp> -#include <algorithm> #include <cassert> namespace proton { @@ -89,104 +87,4 @@ PendingLidTracker::pendingLids() const { return lids; } -TwoPhasePendingLidTracker::TwoPhasePendingLidTracker() - : _sequenceId(0), - _lastCommitStarted(0), - _lastCommitCompleted(0), - _pending() -{} - -TwoPhasePendingLidTracker::~TwoPhasePendingLidTracker() { - assert(_pending.empty()); -} - -IPendingLidTracker::Token -TwoPhasePendingLidTracker::produce(uint32_t lid) { - std::lock_guard guard(_mutex); - _pending[lid] = ++_sequenceId; - return Token(lid, *this); -} -void -TwoPhasePendingLidTracker::consume(uint32_t lid) { - (void) lid; -} - -ILidCommitState::State -TwoPhasePendingLidTracker::waitFor(MonitorGuard & guard, State state, uint32_t lid) const { - for (auto found = _pending.find(lid); found != _pending.end(); found = _pending.find(lid)) { - if (state == State::NEED_COMMIT) { - if (found->second > _lastCommitStarted) { - return State::NEED_COMMIT; - } - return State::WAITING; - } - _cond.wait(guard); - } - return State::COMPLETED; -} - -void -TwoPhasePendingLidTracker::consumeSnapshot(uint64_t sequenceIdWhenStarted) { - MonitorGuard guard(_mutex); - assert(sequenceIdWhenStarted >= _lastCommitCompleted); - _lastCommitCompleted = sequenceIdWhenStarted; - std::vector<uint32_t> committed; - for (const auto & entry : _pending) { - if (entry.second <= sequenceIdWhenStarted) - committed.push_back(entry.first); - } - for (uint32_t lid : committed) { - _pending.erase(lid); - } - _cond.notify_all(); -} - -ILidCommitState::LidList -TwoPhasePendingLidTracker::pendingLids() const { - MonitorGuard guard(_mutex); - LidList lids; - lids.reserve(_pending.size()); - for (const auto & entry : _pending) { - lids.push_back(entry.first); - } - return lids; -} - -namespace common::internal { - -class CommitList : public PendingLidTrackerBase::Payload { -public: - using LidList = ILidCommitState::LidList; - CommitList(uint64_t commitStarted, TwoPhasePendingLidTracker & tracker) - : _tracker(&tracker), - _commitStarted(commitStarted) - { } - CommitList(const CommitList &) = delete; - CommitList & operator = (const CommitList &) = delete; - CommitList & operator = (CommitList &&) = delete; - CommitList(CommitList && rhs) noexcept - : _tracker(rhs._tracker), - _commitStarted(rhs._commitStarted) - { - rhs._tracker = nullptr; - } - ~CommitList() override { - if (_tracker != nullptr) { - _tracker->consumeSnapshot(_commitStarted); - } - } -private: - TwoPhasePendingLidTracker * _tracker; - uint64_t _commitStarted; -}; - -} - -PendingLidTrackerBase::Snapshot -TwoPhasePendingLidTracker::produceSnapshot() { - MonitorGuard guard(_mutex); - _lastCommitStarted = _sequenceId; - return std::make_unique<common::internal::CommitList>(_lastCommitStarted, *this); -} - } diff --git a/searchcore/src/vespa/searchcore/proton/common/pendinglidtracker.h b/searchcore/src/vespa/searchcore/proton/common/pendinglidtracker.h index ef0a1dbb1a3..079634f56bb 100644 --- a/searchcore/src/vespa/searchcore/proton/common/pendinglidtracker.h +++ b/searchcore/src/vespa/searchcore/proton/common/pendinglidtracker.h @@ -60,30 +60,4 @@ private: vespalib::hash_map<uint32_t, uint32_t> _pending; }; -namespace common::internal { - class CommitList; -} -/** - * Use for tracking lids in 2 phases which is needed when visibility-delay is non-zero. - * It tracks lids that are in feed pipeline, lids where commit has been started and when they fully complete. - */ -class TwoPhasePendingLidTracker : public PendingLidTrackerBase -{ -public: - TwoPhasePendingLidTracker(); - ~TwoPhasePendingLidTracker() override; - Token produce(uint32_t lid) override; - Snapshot produceSnapshot() override; -private: - friend common::internal::CommitList; - void consume(uint32_t lid) override; - void consumeSnapshot(uint64_t sequenceIdWhenStarted); - LidList pendingLids() const override; - State waitFor(MonitorGuard & guard, State state, uint32_t lid) const override; - uint64_t _sequenceId; - uint64_t _lastCommitStarted; - uint64_t _lastCommitCompleted; - vespalib::hash_map<uint32_t, uint64_t> _pending; -}; - } diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/CMakeLists.txt b/searchcore/src/vespa/searchcore/proton/documentmetastore/CMakeLists.txt index d257f29c9df..10e0c149ad2 100644 --- a/searchcore/src/vespa/searchcore/proton/documentmetastore/CMakeLists.txt +++ b/searchcore/src/vespa/searchcore/proton/documentmetastore/CMakeLists.txt @@ -12,7 +12,6 @@ vespa_add_library(searchcore_documentmetastore STATIC search_context.cpp lid_allocator.cpp lid_gid_key_comparator.cpp - lid_reuse_delayer_config.cpp lidreusedelayer.cpp lidstatevector.cpp lid_hold_list.cpp diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/lid_reuse_delayer_config.cpp b/searchcore/src/vespa/searchcore/proton/documentmetastore/lid_reuse_delayer_config.cpp deleted file mode 100644 index b04bac5ef26..00000000000 --- a/searchcore/src/vespa/searchcore/proton/documentmetastore/lid_reuse_delayer_config.cpp +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#include "lid_reuse_delayer_config.h" -#include <vespa/searchcore/proton/server/documentdbconfig.h> - -namespace proton::documentmetastore { - -LidReuseDelayerConfig::LidReuseDelayerConfig(const DocumentDBConfig & configSnapshot) - : _visibilityDelay(configSnapshot.getMaintenanceConfigSP()->getVisibilityDelay()), - _hasIndexedOrAttributeFields(configSnapshot.getSchemaSP()->getNumIndexFields() > 0 || - configSnapshot.getSchemaSP()->getNumAttributeFields() > 0) -{ -} - -LidReuseDelayerConfig::LidReuseDelayerConfig() - : LidReuseDelayerConfig(vespalib::duration::zero(), false) -{} - -LidReuseDelayerConfig::LidReuseDelayerConfig(vespalib::duration visibilityDelay, bool hasIndexedOrAttributeFields_in) - : _visibilityDelay(visibilityDelay), - _hasIndexedOrAttributeFields(hasIndexedOrAttributeFields_in) -{ -} - -} diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/lid_reuse_delayer_config.h b/searchcore/src/vespa/searchcore/proton/documentmetastore/lid_reuse_delayer_config.h deleted file mode 100644 index 82dab433a22..00000000000 --- a/searchcore/src/vespa/searchcore/proton/documentmetastore/lid_reuse_delayer_config.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#pragma once - -#include <vespa/vespalib/util/time.h> - -namespace proton { class DocumentDBConfig; } - -namespace proton::documentmetastore { - -/* - * Class representing configuration for lid reuse delayer. - */ -class LidReuseDelayerConfig -{ -private: - vespalib::duration _visibilityDelay; - bool _hasIndexedOrAttributeFields; -public: - LidReuseDelayerConfig(); - LidReuseDelayerConfig(vespalib::duration visibilityDelay, bool _hasIndexedOrAttributeFields_in); - explicit LidReuseDelayerConfig(const DocumentDBConfig &configSnapshot); - vespalib::duration visibilityDelay() const { return _visibilityDelay; } - bool hasIndexedOrAttributeFields() const { return _hasIndexedOrAttributeFields; } -}; - -} diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/lidreusedelayer.cpp b/searchcore/src/vespa/searchcore/proton/documentmetastore/lidreusedelayer.cpp index 03dfd83a132..003812589d1 100644 --- a/searchcore/src/vespa/searchcore/proton/documentmetastore/lidreusedelayer.cpp +++ b/searchcore/src/vespa/searchcore/proton/documentmetastore/lidreusedelayer.cpp @@ -11,12 +11,9 @@ using searchcorespi::index::IThreadingService; using vespalib::makeClosure; using vespalib::makeTask; -LidReuseDelayer::LidReuseDelayer(IThreadingService &writeService, IStore &documentMetaStore, - const LidReuseDelayerConfig & config) +LidReuseDelayer::LidReuseDelayer(IThreadingService &writeService, IStore &documentMetaStore) : _writeService(writeService), _documentMetaStore(documentMetaStore), - _immediateCommit(config.visibilityDelay() == vespalib::duration::zero()), - _config(config), _pendingLids() { } @@ -31,15 +28,8 @@ LidReuseDelayer::delayReuse(uint32_t lid) assert(_writeService.master().isCurrentThread()); if ( ! _documentMetaStore.getFreeListActive()) return false; - if ( ! _immediateCommit) { - _pendingLids.push_back(lid); - return false; - } - if ( ! _config.hasIndexedOrAttributeFields() ) { - _documentMetaStore.removeComplete(lid); - return false; - } - return true; + _pendingLids.push_back(lid); + return false; } bool @@ -48,15 +38,8 @@ LidReuseDelayer::delayReuse(const std::vector<uint32_t> &lids) assert(_writeService.master().isCurrentThread()); if ( ! _documentMetaStore.getFreeListActive() || lids.empty()) return false; - if ( ! _immediateCommit) { - _pendingLids.insert(_pendingLids.end(), lids.cbegin(), lids.cend()); - return false; - } - if ( ! _config.hasIndexedOrAttributeFields()) { - _documentMetaStore.removeBatchComplete(lids); - return false; - } - return true; + _pendingLids.insert(_pendingLids.end(), lids.cbegin(), lids.cend()); + return false; } std::vector<uint32_t> diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/lidreusedelayer.h b/searchcore/src/vespa/searchcore/proton/documentmetastore/lidreusedelayer.h index 0fe16636e1d..07cfbda1dba 100644 --- a/searchcore/src/vespa/searchcore/proton/documentmetastore/lidreusedelayer.h +++ b/searchcore/src/vespa/searchcore/proton/documentmetastore/lidreusedelayer.h @@ -2,8 +2,8 @@ #pragma once -#include "lid_reuse_delayer_config.h" #include <vector> +#include <cstdint> namespace searchcorespi::index { struct IThreadingService; } @@ -26,19 +26,14 @@ class LidReuseDelayer { searchcorespi::index::IThreadingService &_writeService; IStore &_documentMetaStore; - const bool _immediateCommit; - LidReuseDelayerConfig _config; std::vector<uint32_t> _pendingLids; // lids waiting for commit public: - LidReuseDelayer(searchcorespi::index::IThreadingService &writeService, IStore &documentMetaStore, - const LidReuseDelayerConfig & config); + LidReuseDelayer(searchcorespi::index::IThreadingService &writeService, IStore &documentMetaStore); ~LidReuseDelayer(); bool delayReuse(uint32_t lid); bool delayReuse(const std::vector<uint32_t> &lids); std::vector<uint32_t> getReuseLids(); - - const LidReuseDelayerConfig & getConfig() const { return _config; } }; } diff --git a/searchcore/src/vespa/searchcore/proton/persistenceengine/commit_and_wait_document_retriever.cpp b/searchcore/src/vespa/searchcore/proton/persistenceengine/commit_and_wait_document_retriever.cpp index daa240e8b12..aa20627600f 100644 --- a/searchcore/src/vespa/searchcore/proton/persistenceengine/commit_and_wait_document_retriever.cpp +++ b/searchcore/src/vespa/searchcore/proton/persistenceengine/commit_and_wait_document_retriever.cpp @@ -5,10 +5,9 @@ namespace proton { -CommitAndWaitDocumentRetriever::CommitAndWaitDocumentRetriever(IDocumentRetriever::SP retriever, ICommitable &commit, +CommitAndWaitDocumentRetriever::CommitAndWaitDocumentRetriever(IDocumentRetriever::SP retriever, ILidCommitState & unCommittedLidTracker) : _retriever(std::move(retriever)), - _commit(commit), _uncommittedLidsTracker(unCommittedLidTracker) { } @@ -32,7 +31,7 @@ CommitAndWaitDocumentRetriever::getDocumentMetaData(const document::DocumentId & document::Document::UP CommitAndWaitDocumentRetriever::getFullDocument(search::DocumentIdT lid) const { // Ensure that attribute vectors are committed - _commit.commitAndWait(_uncommittedLidsTracker, lid); + _uncommittedLidsTracker.waitComplete(lid); return _retriever->getFullDocument(lid); } @@ -40,7 +39,7 @@ document::Document::UP CommitAndWaitDocumentRetriever::getPartialDocument(search::DocumentIdT lid, const document::DocumentId & docId, const document::FieldSet & fieldSet) const { - _commit.commitAndWait(_uncommittedLidsTracker, lid); + _uncommittedLidsTracker.waitComplete(lid); return _retriever->getPartialDocument(lid, docId, fieldSet); } @@ -48,7 +47,7 @@ void CommitAndWaitDocumentRetriever::visitDocuments(const LidVector &lids, search::IDocumentVisitor &visitor, ReadConsistency readConsistency) const { - _commit.commitAndWait(_uncommittedLidsTracker, lids); + _uncommittedLidsTracker.waitComplete(lids); _retriever->visitDocuments(lids, visitor, readConsistency); } diff --git a/searchcore/src/vespa/searchcore/proton/persistenceengine/commit_and_wait_document_retriever.h b/searchcore/src/vespa/searchcore/proton/persistenceengine/commit_and_wait_document_retriever.h index 8e1ac08fa20..68f34c65362 100644 --- a/searchcore/src/vespa/searchcore/proton/persistenceengine/commit_and_wait_document_retriever.h +++ b/searchcore/src/vespa/searchcore/proton/persistenceengine/commit_and_wait_document_retriever.h @@ -4,7 +4,6 @@ #include "i_document_retriever.h" #include <vespa/searchcore/proton/common/ipendinglidtracker.h> -#include <vespa/searchcore/proton/common/icommitable.h> namespace proton { @@ -16,11 +15,10 @@ namespace proton { class CommitAndWaitDocumentRetriever : public IDocumentRetriever { IDocumentRetriever::SP _retriever; - ICommitable &_commit; ILidCommitState &_uncommittedLidsTracker; using Bucket = storage::spi::Bucket; public: - CommitAndWaitDocumentRetriever(IDocumentRetriever::SP retriever, ICommitable &commit, ILidCommitState & unCommittedLidTracker); + CommitAndWaitDocumentRetriever(IDocumentRetriever::SP retriever, ILidCommitState & unCommittedLidTracker); ~CommitAndWaitDocumentRetriever() override; const document::DocumentTypeRepo &getDocumentTypeRepo() const override; diff --git a/searchcore/src/vespa/searchcore/proton/server/CMakeLists.txt b/searchcore/src/vespa/searchcore/proton/server/CMakeLists.txt index 5b9269917d7..93432221e61 100644 --- a/searchcore/src/vespa/searchcore/proton/server/CMakeLists.txt +++ b/searchcore/src/vespa/searchcore/proton/server/CMakeLists.txt @@ -103,7 +103,6 @@ vespa_add_library(searchcore_server STATIC transactionlogmanager.cpp transactionlogmanagerbase.cpp updatedonecontext.cpp - visibilityhandler.cpp DEPENDS searchcore_attribute searchcore_bucketdb diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp b/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp index b35e1a88495..a176b1cf8c2 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp @@ -166,7 +166,6 @@ DocumentDB::DocumentDB(const vespalib::string &baseDir, _writeFilter(), _transient_memory_usage_provider(std::make_shared<TransientMemoryUsageProvider>()), _feedHandler(std::make_unique<FeedHandler>(_writeService, tlsSpec, docTypeName, *this, _writeFilter, *this, tlsWriterFactory)), - _visibility(*_feedHandler, _writeService, _feedView), _subDBs(*this, *this, *_feedHandler, _docTypeName, _writeService, warmupExecutor, fileHeaderContext, metricsWireService, getMetrics(), queryLimiter, clock, _configMutex, _baseDir, makeSubDBConfig(protonCfg.distribution, @@ -208,8 +207,6 @@ DocumentDB::DocumentDB(const vespalib::string &baseDir, _lidSpaceCompactionHandlers.push_back(std::make_unique<LidSpaceCompactionHandler>(_maintenanceController.getNotReadySubDB(), _docTypeName.getName())); _writeFilter.setConfig(loaded_config->getMaintenanceConfigSP()->getAttributeUsageFilterConfig()); - vespalib::duration visibilityDelay = loaded_config->getMaintenanceConfigSP()->getVisibilityDelay(); - _visibility.setVisibilityDelay(visibilityDelay); } void DocumentDB::registerReference() @@ -441,21 +438,17 @@ DocumentDB::applyConfig(DocumentDBConfig::SP configSnapshot, SerialNum serialNum commit_result = _feedHandler->storeOperationSync(op); sync(op.getSerialNum()); } - bool hasVisibilityDelayChanged = false; { bool elidedConfigSave = equalReplayConfig && tlsReplayDone; // Flush changes to attributes and memory index, cf. visibilityDelay _feedView.get()->forceCommit(elidedConfigSave ? serialNum : serialNum - 1, std::make_shared<search::KeepAlive<FeedHandler::CommitResult>>(std::move(commit_result))); _writeService.sync(); - vespalib::duration visibilityDelay = configSnapshot->getMaintenanceConfigSP()->getVisibilityDelay(); - hasVisibilityDelayChanged = (visibilityDelay != _visibility.getVisibilityDelay()); - _visibility.setVisibilityDelay(visibilityDelay); } if (_state.getState() >= DDBState::State::APPLY_LIVE_CONFIG) { _writeServiceConfig.update(configSnapshot->get_threading_service_config()); } _writeService.setTaskLimit(_writeServiceConfig.defaultTaskLimit(), _writeServiceConfig.defaultTaskLimit()); - if (params.shouldSubDbsChange() || hasVisibilityDelayChanged) { + if (params.shouldSubDbsChange()) { applySubDBConfig(*configSnapshot, serialNum, params); if (serialNum < _feedHandler->getSerialNum()) { // Not last entry in tls. Reprocessing should already be done. @@ -562,7 +555,6 @@ DocumentDB::close() // Abort any ongoing maintenance stopMaintenance(); - _visibility.commit(); _writeService.sync(); // The attributes in the ready sub db is also the total set of attributes. @@ -738,7 +730,7 @@ BucketGuard::UP DocumentDB::lockBucket(const document::BucketId &bucket) std::shared_ptr<std::vector<IDocumentRetriever::SP> > DocumentDB::getDocumentRetrievers(IDocumentRetriever::ReadConsistency consistency) { - return _subDBs.getRetrievers(consistency, _visibility); + return _subDBs.getRetrievers(consistency); } SerialNum @@ -908,22 +900,11 @@ DocumentDB::syncFeedView() IFeedView::SP newFeedView(_subDBs.getFeedView()); _writeService.sync(); - /* - * Don't call commit() on visibility handler during transaction - * log replay since the serial number used for the commit will be - * too high until the replay is complete. This check can be - * removed again when feed handler has improved tracking of serial - * numbers during replay. - */ - if (_state.getAllowReconfig()) { - _visibility.commit(); - } - _writeService.sync(); _feedView.set(newFeedView); _feedHandler->setActiveFeedView(newFeedView.get()); _subDBs.createRetrievers(); - _subDBs.maintenanceSync(_maintenanceController, _visibility); + _subDBs.maintenanceSync(_maintenanceController); // Ensure that old feed view is referenced until all index executor tasks // depending on it has completed. diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdb.h b/searchcore/src/vespa/searchcore/proton/server/documentdb.h index 4c4840446fe..c94b8ffca46 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentdb.h +++ b/searchcore/src/vespa/searchcore/proton/server/documentdb.h @@ -18,8 +18,6 @@ #include "ireplayconfig.h" #include "maintenancecontroller.h" #include "threading_service_config.h" -#include "visibilityhandler.h" - #include <vespa/metrics/updatehook.h> #include <vespa/searchcore/proton/attribute/attribute_usage_filter.h> #include <vespa/searchcore/proton/common/doctypename.h> @@ -139,7 +137,6 @@ private: AttributeUsageFilter _writeFilter; std::shared_ptr<TransientMemoryUsageProvider> _transient_memory_usage_provider; std::unique_ptr<FeedHandler> _feedHandler; - VisibilityHandler _visibility; DocumentSubDBCollection _subDBs; MaintenanceController _maintenanceController; ILidSpaceCompactionHandler::Vector _lidSpaceCompactionHandlers; diff --git a/searchcore/src/vespa/searchcore/proton/server/documentsubdbcollection.cpp b/searchcore/src/vespa/searchcore/proton/server/documentsubdbcollection.cpp index d2b3c7b9d1d..bb1cbcf9371 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentsubdbcollection.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/documentsubdbcollection.cpp @@ -129,26 +129,26 @@ DocumentSubDBCollection::createRetrievers() namespace { IDocumentRetriever::SP -wrapRetriever(IDocumentRetriever::SP retriever, ICommitable &commit, ILidCommitState & unCommitedLidsTracker) +wrapRetriever(IDocumentRetriever::SP retriever, ILidCommitState & unCommitedLidsTracker) { - return std::make_shared<CommitAndWaitDocumentRetriever>(std::move(retriever), commit, unCommitedLidsTracker); + return std::make_shared<CommitAndWaitDocumentRetriever>(std::move(retriever), unCommitedLidsTracker); } } DocumentSubDBCollection::RetrieversSP -DocumentSubDBCollection::getRetrievers(IDocumentRetriever::ReadConsistency consistency, ICommitable & visibilityHandler) { +DocumentSubDBCollection::getRetrievers(IDocumentRetriever::ReadConsistency consistency) { RetrieversSP list = _retrievers.get(); if (consistency == IDocumentRetriever::ReadConsistency::STRONG) { auto wrappedList = std::make_shared<std::vector<IDocumentRetriever::SP>>(); wrappedList->reserve(list->size()); assert(list->size() == 3); - wrappedList->push_back(wrapRetriever((*list)[_readySubDbId], visibilityHandler, + wrappedList->push_back(wrapRetriever((*list)[_readySubDbId], getReadySubDB()->getFeedView()->getUncommittedLidsTracker())); - wrappedList->push_back(wrapRetriever((*list)[_remSubDbId], visibilityHandler, + wrappedList->push_back(wrapRetriever((*list)[_remSubDbId], getRemSubDB()->getFeedView()->getUncommittedLidsTracker())); - wrappedList->push_back(wrapRetriever((*list)[_notReadySubDbId], visibilityHandler, + wrappedList->push_back(wrapRetriever((*list)[_notReadySubDbId], getNotReadySubDB()->getFeedView()->getUncommittedLidsTracker())); return wrappedList; } else { @@ -156,23 +156,23 @@ DocumentSubDBCollection::getRetrievers(IDocumentRetriever::ReadConsistency consi } } -void DocumentSubDBCollection::maintenanceSync(MaintenanceController &mc, ICommitable &commit) { +void DocumentSubDBCollection::maintenanceSync(MaintenanceController &mc) { RetrieversSP retrievers = _retrievers.get(); MaintenanceDocumentSubDB readySubDB(getReadySubDB()->getName(), _readySubDbId, getReadySubDB()->getDocumentMetaStoreContext().getSP(), - wrapRetriever((*retrievers)[_readySubDbId], commit, + wrapRetriever((*retrievers)[_readySubDbId], getReadySubDB()->getFeedView()->getUncommittedLidsTracker()), getReadySubDB()->getFeedView()); MaintenanceDocumentSubDB remSubDB(getRemSubDB()->getName(), _remSubDbId, getRemSubDB()->getDocumentMetaStoreContext().getSP(), - wrapRetriever((*retrievers)[_remSubDbId], commit, getRemSubDB()->getFeedView()->getUncommittedLidsTracker()), + wrapRetriever((*retrievers)[_remSubDbId], getRemSubDB()->getFeedView()->getUncommittedLidsTracker()), getRemSubDB()->getFeedView()); MaintenanceDocumentSubDB notReadySubDB(getNotReadySubDB()->getName(), _notReadySubDbId, getNotReadySubDB()->getDocumentMetaStoreContext().getSP(), - wrapRetriever((*retrievers)[_notReadySubDbId], commit, + wrapRetriever((*retrievers)[_notReadySubDbId], getNotReadySubDB()->getFeedView()->getUncommittedLidsTracker()), getNotReadySubDB()->getFeedView()); mc.syncSubDBs(readySubDB, remSubDB, notReadySubDB); diff --git a/searchcore/src/vespa/searchcore/proton/server/documentsubdbcollection.h b/searchcore/src/vespa/searchcore/proton/server/documentsubdbcollection.h index 83ebef18274..317ec191d60 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentsubdbcollection.h +++ b/searchcore/src/vespa/searchcore/proton/server/documentsubdbcollection.h @@ -32,7 +32,6 @@ class DocumentDBConfig; struct DocumentDBTaggedMetrics; class MaintenanceController; struct MetricsWireService; -class ICommitable; struct IDocumentDBReferenceResolver; class IGetSerialNum; class DocTypeName; @@ -119,10 +118,10 @@ public: void setBucketStateCalculator(const IBucketStateCalculatorSP &calc); void createRetrievers(); - void maintenanceSync(MaintenanceController &mc, ICommitable &commit); + void maintenanceSync(MaintenanceController &mc); // Internally synchronized - RetrieversSP getRetrievers(IDocumentRetriever::ReadConsistency consistency, ICommitable & visibilityHandler); + RetrieversSP getRetrievers(IDocumentRetriever::ReadConsistency consistency); IDocumentSubDB *getReadySubDB() { return _subDBs[_readySubDbId]; } const IDocumentSubDB *getReadySubDB() const { return _subDBs[_readySubDbId]; } diff --git a/searchcore/src/vespa/searchcore/proton/server/fast_access_doc_subdb_configurer.cpp b/searchcore/src/vespa/searchcore/proton/server/fast_access_doc_subdb_configurer.cpp index e0d9f28252f..ab257d56848 100644 --- a/searchcore/src/vespa/searchcore/proton/server/fast_access_doc_subdb_configurer.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/fast_access_doc_subdb_configurer.cpp @@ -18,8 +18,7 @@ void FastAccessDocSubDBConfigurer::reconfigureFeedView(const FastAccessFeedView::SP &curr, const Schema::SP &schema, const std::shared_ptr<const DocumentTypeRepo> &repo, - IAttributeWriter::SP writer, - const LidReuseDelayerConfig & lidReuseDelayerConfig) + IAttributeWriter::SP writer) { _feedView.set(std::make_shared<FastAccessFeedView>( StoreOnlyFeedView::Context(curr->getSummaryAdapter(), @@ -27,8 +26,7 @@ FastAccessDocSubDBConfigurer::reconfigureFeedView(const FastAccessFeedView::SP & curr->getDocumentMetaStore(), curr->getGidToLidChangeHandler(), repo, - curr->getWriteService(), - lidReuseDelayerConfig), + curr->getWriteService()), curr->getPersistentParams(), FastAccessFeedView::Context(std::move(writer),curr->getDocIdLimit()))); } @@ -51,7 +49,7 @@ FastAccessDocSubDBConfigurer::reconfigure(const DocumentDBConfig &newConfig, { FastAccessFeedView::SP oldView = _feedView.get(); IAttributeWriter::SP writer = _factory->create(oldView->getAttributeWriter(), attrSpec); - reconfigureFeedView(oldView, newConfig.getSchemaSP(), newConfig.getDocumentTypeRepoSP(), writer, LidReuseDelayerConfig(newConfig)); + reconfigureFeedView(oldView, newConfig.getSchemaSP(), newConfig.getDocumentTypeRepoSP(), writer); const document::DocumentType *newDocType = newConfig.getDocumentType(); const document::DocumentType *oldDocType = oldConfig.getDocumentType(); diff --git a/searchcore/src/vespa/searchcore/proton/server/fast_access_doc_subdb_configurer.h b/searchcore/src/vespa/searchcore/proton/server/fast_access_doc_subdb_configurer.h index 2c07d904339..dc54bdc421d 100644 --- a/searchcore/src/vespa/searchcore/proton/server/fast_access_doc_subdb_configurer.h +++ b/searchcore/src/vespa/searchcore/proton/server/fast_access_doc_subdb_configurer.h @@ -5,7 +5,6 @@ #include "fast_access_feed_view.h" #include "i_attribute_writer_factory.h" #include <vespa/searchcore/proton/reprocessing/i_reprocessing_initializer.h> -#include <vespa/searchcore/proton/documentmetastore/lid_reuse_delayer_config.h> namespace proton { @@ -17,7 +16,6 @@ class FastAccessDocSubDBConfigurer { public: using FeedViewVarHolder = vespalib::VarHolder<FastAccessFeedView::SP>; - using LidReuseDelayerConfig = documentmetastore::LidReuseDelayerConfig; private: FeedViewVarHolder &_feedView; @@ -27,8 +25,7 @@ private: void reconfigureFeedView(const FastAccessFeedView::SP &curr, const search::index::Schema::SP &schema, const std::shared_ptr<const document::DocumentTypeRepo> &repo, - IAttributeWriter::SP attrWriter, - const LidReuseDelayerConfig & lidReuseDelayerConfig); + IAttributeWriter::SP attrWriter); public: FastAccessDocSubDBConfigurer(FeedViewVarHolder &feedView, diff --git a/searchcore/src/vespa/searchcore/proton/server/maintenance_jobs_injector.h b/searchcore/src/vespa/searchcore/proton/server/maintenance_jobs_injector.h index 44308d49dab..3468ec40923 100644 --- a/searchcore/src/vespa/searchcore/proton/server/maintenance_jobs_injector.h +++ b/searchcore/src/vespa/searchcore/proton/server/maintenance_jobs_injector.h @@ -6,7 +6,6 @@ #include "i_lid_space_compaction_handler.h" #include "i_operation_storer.h" #include "iheartbeathandler.h" -#include <vespa/searchcore/proton/common/icommitable.h> #include <vespa/searchcore/proton/matching/isessioncachepruner.h> #include <vespa/searchcore/proton/metrics/documentdb_job_trackers.h> diff --git a/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.cpp b/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.cpp index 8f34484dfe2..7a15d7122c6 100644 --- a/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.cpp @@ -26,39 +26,21 @@ using matching::OnnxModels; typedef AttributeReprocessingInitializer::Config ARIConfig; void -SearchableDocSubDBConfigurer::reconfigureFeedView(const SearchView::SP &searchView) -{ - SearchableFeedView::SP curr = _feedView.get(); - reconfigureFeedView(curr->getIndexWriter(), - curr->getSummaryAdapter(), - curr->getAttributeWriter(), - curr->getSchema(), - curr->getDocumentTypeRepo(), - searchView, - curr->getLidReuseDelayerConfig()); -} - -void -SearchableDocSubDBConfigurer::reconfigureFeedView(const IIndexWriter::SP &indexWriter, - const ISummaryAdapter::SP &summaryAdapter, - IAttributeWriter::SP attrWriter, +SearchableDocSubDBConfigurer::reconfigureFeedView(IAttributeWriter::SP attrWriter, const Schema::SP &schema, - const std::shared_ptr<const DocumentTypeRepo> &repo, - const SearchView::SP &searchView, - const LidReuseDelayerConfig & lidReuseDelayerConfig) + const std::shared_ptr<const DocumentTypeRepo> &repo) { SearchableFeedView::SP curr = _feedView.get(); _feedView.set(std::make_shared<SearchableFeedView>( - StoreOnlyFeedView::Context(summaryAdapter, + StoreOnlyFeedView::Context(curr->getSummaryAdapter(), schema, - searchView->getDocumentMetaStore(), + curr->getDocumentMetaStore(), curr->getGidToLidChangeHandler(), repo, - curr->getWriteService(), - lidReuseDelayerConfig), + curr->getWriteService()), curr->getPersistentParams(), FastAccessFeedView::Context(std::move(attrWriter), curr->getDocIdLimit()), - SearchableFeedView::Context(indexWriter))); + SearchableFeedView::Context(curr->getIndexWriter()))); } void @@ -147,8 +129,6 @@ SearchableDocSubDBConfigurer::reconfigureIndexSearchable() const IIndexWriter::SP &indexWriter = feedView->getIndexWriter(); const searchcorespi::IIndexManager::SP &indexManager = indexWriter->getIndexManager(); reconfigureMatchView(indexManager->getSearchable()); - const SearchView::SP searchView(_searchView.get()); - reconfigureFeedView(searchView); } void @@ -249,7 +229,6 @@ SearchableDocSubDBConfigurer::reconfigure(const DocumentDBConfig &newConfig, IndexSearchable::SP indexSearchable = searchView->getIndexSearchable(); reconfigureMatchView(matchers, indexSearchable, attrMgr); searchView = _searchView.get(); - shouldFeedViewChange = true; } if (shouldSearchViewChange) { @@ -257,14 +236,9 @@ SearchableDocSubDBConfigurer::reconfigure(const DocumentDBConfig &newConfig, } if (shouldFeedViewChange) { - SearchableFeedView::SP curr = _feedView.get(); - reconfigureFeedView(curr->getIndexWriter(), - curr->getSummaryAdapter(), - std::move(attrWriter), + reconfigureFeedView(std::move(attrWriter), newConfig.getSchemaSP(), - newConfig.getDocumentTypeRepoSP(), - searchView, - LidReuseDelayerConfig(newConfig)); + newConfig.getDocumentTypeRepoSP()); } return initializer; } diff --git a/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.h b/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.h index 0f86520fd0b..6fe0826d578 100644 --- a/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.h +++ b/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.h @@ -36,7 +36,6 @@ class SearchableDocSubDBConfigurer private: typedef vespalib::VarHolder<SearchView::SP> SearchViewHolder; typedef vespalib::VarHolder<SearchableFeedView::SP> FeedViewHolder; - using LidReuseDelayerConfig = documentmetastore::LidReuseDelayerConfig; const ISummaryManager::SP &_summaryMgr; SearchViewHolder &_searchView; FeedViewHolder &_feedView; @@ -46,15 +45,9 @@ private: vespalib::string _subDbName; uint32_t _distributionKey; - void reconfigureFeedView(const SearchView::SP &searchView); - - void reconfigureFeedView(const IIndexWriter::SP &indexWriter, - const ISummaryAdapter::SP &summaryAdapter, - IAttributeWriter::SP attrWriter, + void reconfigureFeedView(IAttributeWriter::SP attrWriter, const search::index::Schema::SP &schema, - const std::shared_ptr<const document::DocumentTypeRepo> &repo, - const SearchView::SP &searchView, - const LidReuseDelayerConfig & lidReuseDelayerConfig); + const std::shared_ptr<const document::DocumentTypeRepo> &repo); void reconfigureMatchView(const searchcorespi::IndexSearchable::SP &indexSearchable); diff --git a/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.cpp b/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.cpp index 481fe799f8f..7f2b8fcaa63 100644 --- a/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.cpp @@ -6,7 +6,6 @@ #include "reconfig_params.h" #include "i_document_subdb_owner.h" #include "ibucketstatecalculator.h" -#include <vespa/searchcore/proton/common/icommitable.h> #include <vespa/searchcore/proton/attribute/attribute_writer.h> #include <vespa/searchcore/proton/flushengine/threadedflushtarget.h> #include <vespa/searchcore/proton/index/index_manager_initializer.h> @@ -17,7 +16,6 @@ #include <vespa/searchlib/fef/indexproperties.h> #include <vespa/searchlib/fef/properties.h> #include <vespa/vespalib/util/closuretask.h> -#include <vespa/eval/tensor/default_tensor_engine.h> using vespa::config::search::RankProfilesConfig; using proton::matching::MatchingStats; @@ -242,9 +240,8 @@ SearchableDocSubDB::initFeedView(IAttributeWriter::SP attrWriter, /** * Handle reconfigure caused by index manager changing state. * - * Flush engine is disabled (for all document dbs) during initial replay and - * recovery feed modes, the flush engine has not started. For a resurrected - * document type, flushing might occur during replay. + * Flush engine is disabled (for all document dbs) during initial replay, the + * flush engine has not started. */ bool SearchableDocSubDB::reconfigure(vespalib::Closure0<bool>::UP closure) @@ -256,7 +253,6 @@ SearchableDocSubDB::reconfigure(vespalib::Closure0<bool>::UP closure) // Everything should be quiet now. SearchView::SP oldSearchView = _rSearchView.get(); - IFeedView::SP oldFeedView = _iFeedView.get(); bool ret = true; diff --git a/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.h b/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.h index 0fcf9b99718..4e021e74189 100644 --- a/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.h +++ b/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.h @@ -26,7 +26,6 @@ class DocumentDBConfig; struct IDocumentDBReferenceResolver; struct MetricsWireService; class GidToLidChangeHandler; -class ICommitable; /** * The searchable sub database supports searching and keeps all attribute fields in memory and diff --git a/searchcore/src/vespa/searchcore/proton/server/storeonlydocsubdb.cpp b/searchcore/src/vespa/searchcore/proton/server/storeonlydocsubdb.cpp index 07ae2adba99..89d244a1176 100644 --- a/searchcore/src/vespa/searchcore/proton/server/storeonlydocsubdb.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/storeonlydocsubdb.cpp @@ -44,7 +44,6 @@ using vespalib::GenericHeader; using search::common::FileHeaderContext; using proton::initializer::InitializerTask; using searchcorespi::IFlushTarget; -using proton::documentmetastore::LidReuseDelayerConfig; namespace proton { @@ -339,7 +338,7 @@ StoreOnlyDocSubDB::getStoreOnlyFeedViewContext(const DocumentDBConfig &configSna { return StoreOnlyFeedView::Context(getSummaryAdapter(), configSnapshot.getSchemaSP(), _metaStoreCtx, *_gidToLidChangeHandler, configSnapshot.getDocumentTypeRepoSP(), - _writeService, LidReuseDelayerConfig(configSnapshot)); + _writeService); } StoreOnlyFeedView::PersistentParams diff --git a/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.cpp b/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.cpp index 56f25b11d9f..9fabdb3cd6c 100644 --- a/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.cpp @@ -223,7 +223,7 @@ StoreOnlyFeedView::StoreOnlyFeedView(const Context &ctx, const PersistentParams _documentMetaStoreContext(ctx._documentMetaStoreContext), _repo(ctx._repo), _docType(nullptr), - _lidReuseDelayer(ctx._writeService, _documentMetaStoreContext->get(), ctx._lidReuseDelayerConfig), + _lidReuseDelayer(ctx._writeService, _documentMetaStoreContext->get()), _pendingLidsForDocStore(), _pendingLidsForCommit(createUncommitedLidTracker()), _schema(ctx._schema), diff --git a/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.h b/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.h index f2819170d5e..c497dea3a19 100644 --- a/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.h +++ b/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.h @@ -13,7 +13,6 @@ #include <vespa/searchcore/proton/common/feeddebugger.h> #include <vespa/searchcore/proton/documentmetastore/documentmetastore.h> #include <vespa/searchcore/proton/documentmetastore/documentmetastorecontext.h> -#include <vespa/searchcore/proton/documentmetastore/lid_reuse_delayer_config.h> #include <vespa/searchcore/proton/documentmetastore/lidreusedelayer.h> #include <vespa/searchcore/proton/feedoperation/lidvectorcontext.h> #include <vespa/searchcore/proton/persistenceengine/resulthandler.h> @@ -68,7 +67,6 @@ public: using PromisedStream = std::promise<vespalib::nbostream>; using DocumentSP = std::shared_ptr<Document>; using DocumentUpdateSP = std::shared_ptr<DocumentUpdate>; - using LidReuseDelayerConfig = documentmetastore::LidReuseDelayerConfig; using LidReuseDelayer = documentmetastore::LidReuseDelayer; using Lid = search::DocumentIdT; @@ -81,22 +79,19 @@ public: IGidToLidChangeHandler &_gidToLidChangeHandler; const std::shared_ptr<const document::DocumentTypeRepo> &_repo; searchcorespi::index::IThreadingService &_writeService; - LidReuseDelayerConfig _lidReuseDelayerConfig; Context(const ISummaryAdapter::SP &summaryAdapter, const search::index::Schema::SP &schema, const IDocumentMetaStoreContext::SP &documentMetaStoreContext, IGidToLidChangeHandler &gidToLidChangeHandler, const std::shared_ptr<const document::DocumentTypeRepo> &repo, - searchcorespi::index::IThreadingService &writeService, - const LidReuseDelayerConfig & lidReuseDelayerConfig) + searchcorespi::index::IThreadingService &writeService) : _summaryAdapter(summaryAdapter), _schema(schema), _documentMetaStoreContext(documentMetaStoreContext), _gidToLidChangeHandler(gidToLidChangeHandler), _repo(repo), - _writeService(writeService), - _lidReuseDelayerConfig(lidReuseDelayerConfig) + _writeService(writeService) {} }; @@ -222,7 +217,6 @@ public: const IDocumentMetaStoreContext::SP &getDocumentMetaStore() const { return _documentMetaStoreContext; } searchcorespi::index::IThreadingService &getWriteService() { return _writeService; } IGidToLidChangeHandler &getGidToLidChangeHandler() const { return _gidToLidChangeHandler; } - LidReuseDelayerConfig getLidReuseDelayerConfig() const { return _lidReuseDelayer.getConfig(); } const std::shared_ptr<const document::DocumentTypeRepo> &getDocumentTypeRepo() const override { return _repo; } const ISimpleDocumentMetaStore *getDocumentMetaStorePtr() const override; diff --git a/searchcore/src/vespa/searchcore/proton/server/visibilityhandler.cpp b/searchcore/src/vespa/searchcore/proton/server/visibilityhandler.cpp deleted file mode 100644 index 3a44af517ee..00000000000 --- a/searchcore/src/vespa/searchcore/proton/server/visibilityhandler.cpp +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#include "visibilityhandler.h" -#include <vespa/vespalib/util/isequencedtaskexecutor.h> -#include <vespa/vespalib/util/lambdatask.h> - -using vespalib::makeLambdaTask; - -namespace proton { - -VisibilityHandler::VisibilityHandler(const IGetSerialNum & serial, - IThreadingService &writeService, - const FeedViewHolder & feedView) - : _serial(serial), - _writeService(writeService), - _feedView(feedView), - _visibilityDelay(vespalib::duration::zero()), - _lastCommitSerialNum(0), - _lock() -{ -} - -VisibilityHandler::~VisibilityHandler() = default; - -void -VisibilityHandler::internalCommit(bool force) -{ - if (_writeService.master().isCurrentThread()) { - performCommit(force); - } else { - std::lock_guard<std::mutex> guard(_lock); - bool wasCommitTaskSpawned = startCommit(guard, force); - (void) wasCommitTaskSpawned; - } -} -void -VisibilityHandler::commit() -{ - if (hasVisibilityDelay()) { - internalCommit(true); - } -} - -void -VisibilityHandler::commitAndWait(ILidCommitState & unCommittedLidTracker) -{ - ILidCommitState::State state = unCommittedLidTracker.getState(); - if (state == ILidCommitState::State::NEED_COMMIT) { - internalCommit(false); - } - if (state != ILidCommitState::State::COMPLETED) { - unCommittedLidTracker.waitComplete(); - } -} - -void -VisibilityHandler::commitAndWait(ILidCommitState & unCommittedLidTracker, uint32_t lid) { - ILidCommitState::State state = unCommittedLidTracker.getState(lid); - if (state == ILidCommitState::State::NEED_COMMIT) { - internalCommit(false); - } - if (state != ILidCommitState::State::COMPLETED) { - unCommittedLidTracker.waitComplete(lid); - } -} -void -VisibilityHandler::commitAndWait(ILidCommitState & unCommittedLidTracker, const std::vector<uint32_t> & lids) { - ILidCommitState::State state = unCommittedLidTracker.getState(lids); - if (state == ILidCommitState::State::NEED_COMMIT) { - internalCommit(false); - } - if (state != ILidCommitState::State::COMPLETED) { - unCommittedLidTracker.waitComplete(lids); - } -} - -bool -VisibilityHandler::startCommit(const std::lock_guard<std::mutex> &unused, bool force) -{ - (void) unused; - SerialNum current = _serial.getSerialNum(); - if ((current > _lastCommitSerialNum) || force) { - _writeService.master().execute(makeLambdaTask([this, force]() { performCommit(force);})); - return true; - } - return false; -} - -void -VisibilityHandler::performCommit(bool force) -{ - // Called by master thread - SerialNum current = _serial.getSerialNum(); - if ((current > _lastCommitSerialNum) || force) { - IFeedView::SP feedView(_feedView.get()); - if (feedView) { - feedView->forceCommit(current); - _lastCommitSerialNum = current; - } - } -} - -} diff --git a/searchcore/src/vespa/searchcore/proton/server/visibilityhandler.h b/searchcore/src/vespa/searchcore/proton/server/visibilityhandler.h deleted file mode 100644 index c22128685c1..00000000000 --- a/searchcore/src/vespa/searchcore/proton/server/visibilityhandler.h +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#pragma once - -#include "ifeedview.h" -#include "igetserialnum.h" -#include <vespa/searchcore/proton/common/icommitable.h> -#include <vespa/searchcorespi/index/ithreadingservice.h> -#include <vespa/vespalib/util/varholder.h> -#include <vespa/vespalib/util/time.h> -#include <mutex> - -namespace proton { - -/** - * Handle commit of changes withing the allowance of visibilitydelay. - * It will both handle background commit jobs and the necessary commit and wait for sequencing. - **/ -class VisibilityHandler : public ICommitable -{ - using IThreadingService = searchcorespi::index::IThreadingService; - using FeedViewHolder = vespalib::VarHolder<IFeedView::SP>; -public: - typedef search::SerialNum SerialNum; - VisibilityHandler(const IGetSerialNum &serial, - IThreadingService &threadingService, - const FeedViewHolder &feedView); - ~VisibilityHandler() override; - void setVisibilityDelay(vespalib::duration visibilityDelay) { _visibilityDelay = visibilityDelay; } - vespalib::duration getVisibilityDelay() const { return _visibilityDelay; } - bool hasVisibilityDelay() const { return _visibilityDelay != vespalib::duration::zero(); } - void commit() override; - void commitAndWait(ILidCommitState & unCommittedLidTracker) override; - void commitAndWait(ILidCommitState &, uint32_t ) override; - void commitAndWait(ILidCommitState &, const std::vector<uint32_t> & ) override; -private: - bool startCommit(const std::lock_guard<std::mutex> &unused, bool force); - void performCommit(bool force); - void internalCommit(bool force); - const IGetSerialNum & _serial; - IThreadingService & _writeService; - const FeedViewHolder & _feedView; - vespalib::duration _visibilityDelay; - SerialNum _lastCommitSerialNum; - std::mutex _lock; -}; - -} diff --git a/storage/src/tests/persistence/mergehandlertest.cpp b/storage/src/tests/persistence/mergehandlertest.cpp index 02527883022..335863322d9 100644 --- a/storage/src/tests/persistence/mergehandlertest.cpp +++ b/storage/src/tests/persistence/mergehandlertest.cpp @@ -985,7 +985,7 @@ MergeHandlerTest::HandleApplyBucketDiffReplyInvoker::beforeInvoke( auto cmd = std::make_shared<api::MergeBucketCommand>(test._bucket, test._nodes, test._maxTimestamp); handler.handleMergeBucket(*cmd, test.createTracker(cmd, test._bucket)); auto diffCmd = test.fetchSingleMessage<api::GetBucketDiffCommand>(); - auto dummyDiff = test.createDummyGetBucketDiff(100000 * _counter, 0x4); + auto dummyDiff = test.createDummyGetBucketDiff(100000 * _counter, 0x2); diffCmd->getDiff() = dummyDiff->getDiff(); api::GetBucketDiffReply diffReply(*diffCmd); @@ -1294,8 +1294,8 @@ TEST_F(MergeHandlerTest, partially_filled_apply_bucket_diff_reply) // Node 4 has been eliminated before the first ApplyBucketDiff command EXPECT_EQ((NodeList{{0, false}, {1, false}, {2, true}, {3, true}}), s.nodeList); EXPECT_EQ(baseline_diff_size + 2u, s.diff.size()); - EXPECT_EQ(EntryCheck(20000, 8u), s.diff[baseline_diff_size]); - EXPECT_EQ(EntryCheck(20100, 8u), s.diff[baseline_diff_size + 1]); + EXPECT_EQ(EntryCheck(20000, 24u), s.diff[baseline_diff_size]); + EXPECT_EQ(EntryCheck(20100, 24u), s.diff[baseline_diff_size + 1]); auto& cmd3 = dynamic_cast<api::ApplyBucketDiffCommand&>(*messageKeeper()._msgs[1]); // ApplyBucketDiffCommand has a shorter node list, node 2 is not present EXPECT_EQ((NodeList{{0, false}, {1, false}, {3, true}}), cmd3.getNodes()); @@ -1321,15 +1321,15 @@ TEST_F(MergeHandlerTest, partially_filled_apply_bucket_diff_reply) auto &s = getEnv()._fileStorHandler.editMergeStatus(_bucket); EXPECT_EQ((NodeList{{0, false}, {1, false}, {2, true}, {3, true}}), s.nodeList); EXPECT_EQ(baseline_diff_size + 1u, s.diff.size()); - EXPECT_EQ(EntryCheck(20100, 8u), s.diff[baseline_diff_size]); + EXPECT_EQ(EntryCheck(20100, 24u), s.diff[baseline_diff_size]); auto& cmd4 = dynamic_cast<api::ApplyBucketDiffCommand&>(*messageKeeper()._msgs[2]); EXPECT_EQ((NodeList{{0, false}, {1, false}, {3, true}}), cmd4.getNodes()); auto reply = std::make_unique<api::ApplyBucketDiffReply>(cmd4); auto& diff = reply->getDiff(); EXPECT_EQ(1u, diff.size()); EXPECT_EQ(EntryCheck(20100u, 4u), diff[0]._entry); - fill_entry(diff[0], *doc2, getEnv().getDocumentTypeRepo()); - diff[0]._entry._hasMask |= 2u; + // Simulate that node 3 somehow lost doc2 when trying to fill diff entry. + diff[0]._entry._hasMask &= ~4u; handler.handleApplyBucketDiffReply(*reply, messageKeeper()); LOG(debug, "handled second ApplyBucketDiffReply"); } @@ -1341,7 +1341,8 @@ TEST_F(MergeHandlerTest, partially_filled_apply_bucket_diff_reply) auto &s = getEnv()._fileStorHandler.editMergeStatus(_bucket); // Nodes 3 and 2 have been eliminated before the third ApplyBucketDiff command EXPECT_EQ((NodeList{{0, false}, {1, false}}), s.nodeList); - EXPECT_EQ(baseline_diff_size, s.diff.size()); + EXPECT_EQ(baseline_diff_size + 1u, s.diff.size()); + EXPECT_EQ(EntryCheck(20100, 16u), s.diff[baseline_diff_size]); auto& cmd5 = dynamic_cast<api::ApplyBucketDiffCommand&>(*messageKeeper()._msgs[3]); EXPECT_EQ((NodeList{{0, false}, {1, false}}), cmd5.getNodes()); auto reply = std::make_unique<api::ApplyBucketDiffReply>(cmd5); @@ -1355,7 +1356,27 @@ TEST_F(MergeHandlerTest, partially_filled_apply_bucket_diff_reply) LOG(debug, "handled third ApplyBucketDiffReply"); } ASSERT_EQ(5u, messageKeeper()._msgs.size()); - ASSERT_EQ(api::MessageType::MERGEBUCKET_REPLY, messageKeeper()._msgs[4]->getType()); + ASSERT_EQ(api::MessageType::APPLYBUCKETDIFF, messageKeeper()._msgs[4]->getType()); + { + LOG(debug, "checking fourth ApplyBucketDiff command"); + EXPECT_TRUE(getEnv()._fileStorHandler.isMerging(_bucket)); + auto &s = getEnv()._fileStorHandler.editMergeStatus(_bucket); + // All nodes in use again due to failure to fill diff entry for doc2 + EXPECT_EQ((NodeList{{0, false}, {1, false}, {2, true}, {3, true}, {4, true}}), s.nodeList); + EXPECT_EQ(1u, s.diff.size()); + EXPECT_EQ(EntryCheck(20100, 16u), s.diff[0]); + auto& cmd6 = dynamic_cast<api::ApplyBucketDiffCommand&>(*messageKeeper()._msgs[4]); + EXPECT_EQ((NodeList{{0, false}, {1, false}, {4, true}}), cmd6.getNodes()); + auto reply = std::make_unique<api::ApplyBucketDiffReply>(cmd6); + auto& diff = reply->getDiff(); + EXPECT_EQ(1u, diff.size()); + fill_entry(diff[0], *doc2, getEnv().getDocumentTypeRepo()); + diff[0]._entry._hasMask |= 2u; + handler.handleApplyBucketDiffReply(*reply, messageKeeper()); + LOG(debug, "handled fourth ApplyBucketDiffReply"); + } + ASSERT_EQ(6u, messageKeeper()._msgs.size()); + ASSERT_EQ(api::MessageType::MERGEBUCKET_REPLY, messageKeeper()._msgs[5]->getType()); LOG(debug, "got mergebucket reply"); } diff --git a/storage/src/vespa/storage/persistence/filestorage/mergestatus.cpp b/storage/src/vespa/storage/persistence/filestorage/mergestatus.cpp index 2ecef59b567..2e390db69be 100644 --- a/storage/src/vespa/storage/persistence/filestorage/mergestatus.cpp +++ b/storage/src/vespa/storage/persistence/filestorage/mergestatus.cpp @@ -12,7 +12,7 @@ namespace storage { MergeStatus::MergeStatus(const framework::Clock& clock, const metrics::LoadType& lt, api::StorageMessage::Priority priority, uint32_t traceLevel) - : reply(), nodeList(), maxTimestamp(0), diff(), pendingId(0), + : reply(), full_node_list(), nodeList(), maxTimestamp(0), diff(), pendingId(0), pendingGetDiff(), pendingApplyDiff(), timeout(0), startTime(clock), context(lt, priority, traceLevel) {} diff --git a/storage/src/vespa/storage/persistence/filestorage/mergestatus.h b/storage/src/vespa/storage/persistence/filestorage/mergestatus.h index 18ced81c280..51930f337c6 100644 --- a/storage/src/vespa/storage/persistence/filestorage/mergestatus.h +++ b/storage/src/vespa/storage/persistence/filestorage/mergestatus.h @@ -18,6 +18,7 @@ public: using SP = std::shared_ptr<MergeStatus>; std::shared_ptr<api::StorageReply> reply; + std::vector<api::MergeBucketCommand::Node> full_node_list; std::vector<api::MergeBucketCommand::Node> nodeList; framework::MicroSecTime maxTimestamp; std::deque<api::GetBucketDiffCommand::Entry> diff; diff --git a/storage/src/vespa/storage/persistence/mergehandler.cpp b/storage/src/vespa/storage/persistence/mergehandler.cpp index 6e7fc30bd6c..51b575548d8 100644 --- a/storage/src/vespa/storage/persistence/mergehandler.cpp +++ b/storage/src/vespa/storage/persistence/mergehandler.cpp @@ -651,18 +651,22 @@ MergeHandler::applyDiffLocally( } namespace { - void findCandidates(MergeStatus& status, bool constrictHasMask, uint16_t hasMask, + void findCandidates(MergeStatus& status, uint16_t active_nodes_mask, bool constrictHasMask, uint16_t hasMask, uint16_t newHasMask, api::ApplyBucketDiffCommand& cmd) { uint32_t chunkSize = 0; for (const auto& entry : status.diff) { - if (constrictHasMask && entry._hasMask != hasMask) { + uint16_t entry_has_mask = (entry._hasMask & active_nodes_mask); + if ((entry_has_mask == 0u) || + (constrictHasMask && (entry_has_mask != hasMask))) { continue; } chunkSize += entry._bodySize + entry._headerSize; cmd.getDiff().emplace_back(entry); if (constrictHasMask) { cmd.getDiff().back()._entry._hasMask = newHasMask; + } else { + cmd.getDiff().back()._entry._hasMask = entry_has_mask; } } } @@ -690,52 +694,70 @@ MergeHandler::processBucketMerge(const spi::Bucket& bucket, MergeStatus& status, LOG(spam, "Processing merge of %s. %u entries left to merge.", bucket.toString().c_str(), (uint32_t) status.diff.size()); std::shared_ptr<api::ApplyBucketDiffCommand> cmd; - - // If we still have a source only node, eliminate that one from the - // merge. - while (status.nodeList.back().sourceOnly) { - std::vector<api::MergeBucketCommand::Node> nodes; - for (const auto& node : status.nodeList) { - if (!node.sourceOnly) { - nodes.emplace_back(node); + std::map<uint16_t, uint32_t> counts; + + uint16_t active_nodes_mask; + do { + active_nodes_mask = (1u << status.nodeList.size()) - 1; + // If we still have a source only node, eliminate that one from the + // merge. + while (status.nodeList.back().sourceOnly) { + std::vector<api::MergeBucketCommand::Node> nodes; + for (const auto& node : status.nodeList) { + if (!node.sourceOnly) { + nodes.emplace_back(node); + } + } + nodes.push_back(status.nodeList.back()); + assert(nodes.size() > 1); + + cmd = std::make_shared<api::ApplyBucketDiffCommand>(bucket.getBucket(), nodes); + cmd->setAddress(createAddress(_clusterName, nodes[1].index)); + findCandidates(status, + active_nodes_mask, + true, + 1 << (status.nodeList.size() - 1), + 1 << (nodes.size() - 1), + *cmd); + if (cmd->getDiff().size() != 0) { + break; + } + cmd.reset(); + // If we found no data to merge from the last source only node, + // remove it and retry. + status.nodeList.pop_back(); + active_nodes_mask = (1u << status.nodeList.size()) - 1; + // If only one node left in the merge, return ok. + if (status.nodeList.size() == 1) { + LOG(debug, "Done with merge of %s as there is only one node " + "that is not source only left in the merge.", + bucket.toString().c_str()); + return status.reply; } } - nodes.push_back(status.nodeList.back()); - assert(nodes.size() > 1); - - cmd = std::make_shared<api::ApplyBucketDiffCommand>(bucket.getBucket(), nodes); - cmd->setAddress(createAddress(_clusterName, nodes[1].index)); - findCandidates(status, - true, - 1 << (status.nodeList.size() - 1), - 1 << (nodes.size() - 1), - *cmd); - if (cmd->getDiff().size() != 0) break; - cmd.reset(); - // If we found no data to merge from the last source only node, - // remove it and retry. (Clear it out of the hasmask such that we - // can match hasmask with operator==) - status.nodeList.pop_back(); - uint16_t mask = ~(1 << status.nodeList.size()); - for (auto& e : status.diff) { - e._hasMask &= mask; - } - // If only one node left in the merge, return ok. - if (status.nodeList.size() == 1) { - LOG(debug, "Done with merge of %s as there is only one node " - "that is not source only left in the merge.", - bucket.toString().c_str()); - return status.reply; + if (!cmd) { + // If we did not have a source only node, check if we have a path with + // many documents within it that we'll merge separately + counts.clear(); + for (const auto& e : status.diff) { + ++counts[e._hasMask & active_nodes_mask]; + } + if (counts.size() == 1 && + counts.begin()->first == 0u && + status.nodeList.size() < status.full_node_list.size()) { + // Diff not empty, but none of the remaining nodes have any merge entries. + // Bring back source only nodes that might still have merge entries. + status.nodeList = status.full_node_list; + continue; + } } - } - // If we did not have a source only node, check if we have a path with - // many documents within it that we'll merge separately + break; + } while (true); if (!cmd) { - std::map<uint16_t, uint32_t> counts; - for (const auto& e : status.diff) { - ++counts[e._hasMask]; - } for (const auto& e : counts) { + if (e.first == 0u) { + continue; + } if (e.second >= uint32_t(_commonMergeChainOptimalizationMinimumSize) || counts.size() == 1) { @@ -769,7 +791,7 @@ MergeHandler::processBucketMerge(const spi::Bucket& bucket, MergeStatus& status, cmd->setAddress(createAddress(_clusterName, nodes[1].index)); // Add all the metadata, and thus use big limit. Max // data to fetch parameter will control amount added. - findCandidates(status, true, e.first, newMask, *cmd); + findCandidates(status, active_nodes_mask, true, e.first, newMask, *cmd); break; } } @@ -780,7 +802,7 @@ MergeHandler::processBucketMerge(const spi::Bucket& bucket, MergeStatus& status, if ( ! cmd ) { cmd = std::make_shared<api::ApplyBucketDiffCommand>(bucket.getBucket(), status.nodeList); cmd->setAddress(createAddress(_clusterName, status.nodeList[1].index)); - findCandidates(status, false, 0, 0, *cmd); + findCandidates(status, active_nodes_mask, false, 0, 0, *cmd); } cmd->setPriority(status.context.getPriority()); cmd->setTimeout(status.timeout); @@ -868,6 +890,7 @@ MergeHandler::handleMergeBucket(api::MergeBucketCommand& cmd, MessageTracker::UP _clock, cmd.getLoadType(), cmd.getPriority(), cmd.getTrace().getLevel()); _env._fileStorHandler.addMergeStatus(bucket.getBucket(), s); + s->full_node_list = cmd.getNodes(); s->nodeList = cmd.getNodes(); s->maxTimestamp = Timestamp(cmd.getMaxTimestamp()); s->timeout = cmd.getTimeout(); diff --git a/vespalib/src/tests/rendezvous/rendezvous_test.cpp b/vespalib/src/tests/rendezvous/rendezvous_test.cpp index bce25692760..f4ec7870ad5 100644 --- a/vespalib/src/tests/rendezvous/rendezvous_test.cpp +++ b/vespalib/src/tests/rendezvous/rendezvous_test.cpp @@ -1,7 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/vespalib/testkit/test_kit.h> #include <vespa/vespalib/util/rendezvous.h> +#include <vespa/vespalib/util/time.h> #include <utility> +#include <thread> using namespace vespalib; @@ -10,28 +12,45 @@ struct Value { Value() : value(42) {} }; -template <typename T> -struct Empty : Rendezvous<int, T> { - Empty(size_t n) : Rendezvous<int, T>(n) {} +template <typename T, bool ext_id> +struct Empty : Rendezvous<int, T, ext_id> { + Empty(size_t n) : Rendezvous<int, T, ext_id>(n) {} void mingle() override {} - T meet() { return this->rendezvous(0); } + T meet(size_t thread_id) { + if constexpr (ext_id) { + return this->rendezvous(0, thread_id); + } else { + (void) thread_id; + return this->rendezvous(0); + } + } }; -struct Add : Rendezvous<size_t, std::pair<size_t, size_t> > { - Add(size_t n) : Rendezvous<size_t, std::pair<size_t, size_t> >(n) {} +template <bool ext_id> +struct Add : Rendezvous<size_t, std::pair<size_t, size_t>, ext_id> { + using Super = Rendezvous<size_t, std::pair<size_t, size_t>, ext_id>; + using Super::size; + using Super::in; + using Super::out; + Add(size_t n) : Super(n) {} void mingle() override { size_t sum = 0; for (size_t i = 0; i < size(); ++i) { sum += in(i); } - for (size_t i = 0; i < size(); ++i) { + for (size_t i = 0; i < this->size(); ++i) { out(i) = std::make_pair(sum, in(0)); } } }; -struct Modify : Rendezvous<size_t, size_t> { - Modify(size_t n) : Rendezvous<size_t, size_t>(n) {} +template <bool ext_id> +struct Modify : Rendezvous<size_t, size_t, ext_id> { + using Super = Rendezvous<size_t, size_t, ext_id>; + using Super::size; + using Super::in; + using Super::out; + Modify(size_t n) : Super(n) {} void mingle() override { for (size_t i = 0; i < size(); ++i) { in(i) += 1; @@ -42,60 +61,122 @@ struct Modify : Rendezvous<size_t, size_t> { } }; -template <typename T> -struct Swap : Rendezvous<T, T> { - using Rendezvous<T, T>::in; - using Rendezvous<T, T>::out; - Swap() : Rendezvous<T, T>(2) {} +template <typename T, bool ext_id> +struct Swap : Rendezvous<T, T, ext_id> { + using Super = Rendezvous<T, T, ext_id>; + using Super::size; + using Super::in; + using Super::out; + Swap() : Super(2) {} void mingle() override { out(0) = std::move(in(1)); out(1) = std::move(in(0)); } }; +template <bool ext_id> +struct DetectId : Rendezvous<int, size_t, ext_id> { + using Super = Rendezvous<int, size_t, ext_id>; + using Super::size; + using Super::in; + using Super::out; + DetectId(size_t n) : Super(n) {} + void mingle() override { + for (size_t i = 0; i < size(); ++i) { + out(i) = i; + } + } + size_t meet(size_t thread_id) { + if constexpr (ext_id) { + return this->rendezvous(0, thread_id); + } else { + (void) thread_id; + return this->rendezvous(0); + } + } +}; + +struct Any : Rendezvous<bool, bool> { + Any(size_t n) : Rendezvous<bool, bool>(n) {} + void mingle() override { + bool result = false; + for (size_t i = 0; i < size(); ++i) { + result |= in(i); + } + for (size_t i = 0; i < size(); ++i) { + out(i) = result; + } + } + bool check(bool flag) { return this->rendezvous(flag); } +}; + TEST("require that creating an empty rendezvous will fail") { - EXPECT_EXCEPTION(Add(0), IllegalArgumentException, ""); + EXPECT_EXCEPTION(Add<false>(0), IllegalArgumentException, ""); + EXPECT_EXCEPTION(Add<true>(0), IllegalArgumentException, ""); } -TEST_F("require that a single thread can mingle with itself within a rendezvous", Add(1)) { +TEST_FF("require that a single thread can mingle with itself within a rendezvous", Add<false>(1), Add<true>(1)) { EXPECT_EQUAL(10u, f1.rendezvous(10).first); EXPECT_EQUAL(20u, f1.rendezvous(20).first); EXPECT_EQUAL(30u, f1.rendezvous(30).first); + EXPECT_EQUAL(10u, f2.rendezvous(10, thread_id).first); + EXPECT_EQUAL(20u, f2.rendezvous(20, thread_id).first); + EXPECT_EQUAL(30u, f2.rendezvous(30, thread_id).first); } -TEST_MT_F("require that rendezvous can mingle multiple threads", 10, Add(num_threads)) { +TEST_MT_FF("require that rendezvous can mingle multiple threads", 10, Add<false>(num_threads), Add<true>(num_threads)) { EXPECT_EQUAL(45u, f1.rendezvous(thread_id).first); + EXPECT_EQUAL(45u, f2.rendezvous(thread_id, thread_id).first); } -typedef Empty<Value> Empty1; -typedef Empty<size_t> Empty2; -TEST_MT_FF("require that unset rendezvous outputs are default constructed", 10, Empty1(num_threads), Empty2(num_threads)) { - EXPECT_EQUAL(42u, f1.meet().value); - EXPECT_EQUAL(0u, f2.meet()); +template <bool ext_id> using Empty1 = Empty<Value, ext_id>; +template <bool ext_id> using Empty2 = Empty<size_t, ext_id>; + +TEST_MT_FFFF("require that unset rendezvous outputs are default constructed", 10, + Empty1<false>(num_threads), Empty2<false>(num_threads), + Empty1<true>(num_threads), Empty2<true>(num_threads)) +{ + EXPECT_EQUAL(42u, f1.meet(thread_id).value); + EXPECT_EQUAL(0u, f2.meet(thread_id)); + EXPECT_EQUAL(42u, f3.meet(thread_id).value); + EXPECT_EQUAL(0u, f4.meet(thread_id)); } -TEST_MT_FF("require that mingle is not called until all threads are present", 3, Add(num_threads), - CountDownLatch(num_threads - 1)) +TEST_MT_FFFF("require that mingle is not called until all threads are present", 3, + Add<false>(num_threads), CountDownLatch(num_threads - 1), + Add<true>(num_threads), CountDownLatch(num_threads - 1)) { - if (thread_id == 0) { - EXPECT_FALSE(f2.await(20)); - EXPECT_EQUAL(3u, f1.rendezvous(thread_id).first); - EXPECT_TRUE(f2.await(25000)); - } else { - EXPECT_EQUAL(3u, f1.rendezvous(thread_id).first); - f2.countDown(); + for (bool ext_id: {false, true}) { + CountDownLatch &latch = ext_id ? f4 : f2; + if (thread_id == 0) { + EXPECT_FALSE(latch.await(20)); + if (ext_id) { + EXPECT_EQUAL(3u, f3.rendezvous(thread_id, thread_id).first); + } else { + EXPECT_EQUAL(3u, f1.rendezvous(thread_id).first); + } + EXPECT_TRUE(latch.await(25000)); + } else { + if (ext_id) { + EXPECT_EQUAL(3u, f3.rendezvous(thread_id, thread_id).first); + } else { + EXPECT_EQUAL(3u, f1.rendezvous(thread_id).first); + } + latch.countDown(); + } } } -TEST_MT_F("require that rendezvous can be used multiple times", 10, Add(num_threads)) { - EXPECT_EQUAL(45u, f1.rendezvous(thread_id).first); - EXPECT_EQUAL(45u, f1.rendezvous(thread_id).first); +TEST_MT_FF("require that rendezvous can be used multiple times", 10, Add<false>(num_threads), Add<true>(num_threads)) { EXPECT_EQUAL(45u, f1.rendezvous(thread_id).first); + EXPECT_EQUAL(45u, f2.rendezvous(thread_id, thread_id).first); EXPECT_EQUAL(45u, f1.rendezvous(thread_id).first); + EXPECT_EQUAL(45u, f2.rendezvous(thread_id, thread_id).first); EXPECT_EQUAL(45u, f1.rendezvous(thread_id).first); + EXPECT_EQUAL(45u, f2.rendezvous(thread_id, thread_id).first); } -TEST_MT_FF("require that rendezvous can be run with additional threads", 100, Add(10), CountDownLatch(10)) { +TEST_MT_FF("require that rendezvous can be run with additional threads", 100, Add<false>(10), CountDownLatch(10)) { auto res = f1.rendezvous(thread_id); TEST_BARRIER(); if (res.second == thread_id) { @@ -105,16 +186,46 @@ TEST_MT_FF("require that rendezvous can be run with additional threads", 100, Ad EXPECT_TRUE(f2.await(25000)); } -TEST_MT_F("require that mingle can modify its own copy of input values", 10, Modify(num_threads)) { +TEST_MT_FF("require that mingle can modify its own copy of input values", 10, Modify<false>(num_threads), Modify<true>(num_threads)) { size_t my_input = thread_id; - size_t my_output = f1.rendezvous(my_input); + size_t my_output1 = f1.rendezvous(my_input); + size_t my_output2 = f2.rendezvous(my_input, thread_id); EXPECT_EQUAL(my_input, thread_id); - EXPECT_EQUAL(my_output, thread_id + 1); + EXPECT_EQUAL(my_output1, thread_id + 1); + EXPECT_EQUAL(my_output2, thread_id + 1); } -TEST_MT_F("require that threads can exchange non-copyable state", 2, Swap<std::unique_ptr<size_t> >()) { - auto other = f1.rendezvous(std::make_unique<size_t>(thread_id)); - EXPECT_EQUAL(*other, 1 - thread_id); +using Swap_false = Swap<std::unique_ptr<size_t>,false>; +using Swap_true = Swap<std::unique_ptr<size_t>,true>; + +TEST_MT_FF("require that threads can exchange non-copyable state", 2, Swap_false(), Swap_true()) { + auto other1 = f1.rendezvous(std::make_unique<size_t>(thread_id)); + EXPECT_EQUAL(*other1, 1 - thread_id); + auto other2 = f2.rendezvous(std::make_unique<size_t>(thread_id), thread_id); + EXPECT_EQUAL(*other2, 1 - thread_id); +} + +TEST_MT_F("require that participation id can be explicitly defined", 10, DetectId<true>(num_threads)) { + for (size_t i = 0; i < 128; ++i) { + size_t my_id = f1.meet(thread_id); + EXPECT_EQUAL(my_id, thread_id); + } +} + +TEST_MT_FF("require that participation id is unstable when not explicitly defined", 10, DetectId<false>(num_threads), Any(num_threads)) { + bool id_mismatch = false; + size_t old_id = f1.meet(thread_id); + for (size_t i = 0; !id_mismatch; ++i) { + if ((i % num_threads) == thread_id) { + std::this_thread::sleep_for(std::chrono::milliseconds(i)); + } + size_t new_id = f1.meet(thread_id); + if (new_id != old_id) { + id_mismatch = true; + } + id_mismatch = f2.check(id_mismatch); + } + EXPECT_TRUE(id_mismatch); } TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/vespalib/src/vespa/vespalib/util/rendezvous.h b/vespalib/src/vespa/vespalib/util/rendezvous.h index be589f57b9d..e7259117272 100644 --- a/vespalib/src/vespa/vespalib/util/rendezvous.h +++ b/vespalib/src/vespa/vespalib/util/rendezvous.h @@ -2,6 +2,7 @@ #pragma once +#include <type_traits> #include <condition_variable> #include <vector> @@ -18,7 +19,7 @@ namespace vespalib { * subclass needs to implement the mingle function to supply the * application logic. **/ -template <typename IN, typename OUT> +template <typename IN, typename OUT, bool external_id = false> class Rendezvous { private: @@ -36,6 +37,17 @@ private: **/ virtual void mingle() = 0; + /** + * lock-free version for when there is only one thread meeting + * itself. + **/ + void meet_self(IN &input, OUT &output); + + /** + * general version for when there are multiple threads meeting. + **/ + void meet_others(IN &input, OUT &output, size_t my_id, std::unique_lock<std::mutex> guard); + protected: /** * Obtain the number of input and output values to be handled by @@ -81,7 +93,22 @@ public: * @return output parameter for a single thread * @param input input parameter for a single thread **/ - OUT rendezvous(IN input); + template <bool ext_id = external_id> + typename std::enable_if<!ext_id,OUT>::type rendezvous(IN input); + + /** + * Called by individual threads to synchronize execution and share + * state with the mingle function where each caller has a + * pre-defined participation id (enable by setting the external_id + * template flag). + * + * @return output parameter for a single thread + * @param input input parameter for a single thread + * @param my_id participant id for this thread (must be in range and + * not conflicting with other threads) + **/ + template <bool ext_id = external_id> + typename std::enable_if<ext_id,OUT>::type rendezvous(IN input, size_t my_id); }; } // namespace vespalib diff --git a/vespalib/src/vespa/vespalib/util/rendezvous.hpp b/vespalib/src/vespa/vespalib/util/rendezvous.hpp index 2af5a55c8ab..284b536460a 100644 --- a/vespalib/src/vespa/vespalib/util/rendezvous.hpp +++ b/vespalib/src/vespa/vespalib/util/rendezvous.hpp @@ -1,52 +1,91 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "exceptions.h" +#include <cassert> namespace vespalib { -template <typename IN, typename OUT> -Rendezvous<IN, OUT>::Rendezvous(size_t n) +template <typename IN, typename OUT, bool external_id> +void +Rendezvous<IN, OUT, external_id>::meet_self(IN &input, OUT &output) { + _in[0] = &input; + _out[0] = &output; + mingle(); +} + +template <typename IN, typename OUT, bool external_id> +void +Rendezvous<IN, OUT, external_id>::meet_others(IN &input, OUT &output, size_t my_id, std::unique_lock<std::mutex> guard) +{ + if (external_id) { + assert(_in[my_id] == nullptr); + assert(_out[my_id] == nullptr); + } + _in[my_id] = &input; + _out[my_id] = &output; + if (++_next == _size) { + mingle(); + if (external_id) { + std::fill(_in.begin(), _in.end(), nullptr); + std::fill(_out.begin(), _out.end(), nullptr); + } + _next = 0; + ++_gen; + _cond.notify_all(); + } else { + size_t oldgen = _gen; + while (oldgen == _gen) { + _cond.wait(guard); + } + } +} + +template <typename IN, typename OUT, bool external_id> +Rendezvous<IN, OUT, external_id>::Rendezvous(size_t n) : _lock(), _cond(), _size(n), _next(0), _gen(0), - _in(n, 0), - _out(n, 0) + _in(n, nullptr), + _out(n, nullptr) { if (n == 0) { throw IllegalArgumentException("size must be greater than 0"); } } -template <typename IN, typename OUT> -Rendezvous<IN, OUT>::~Rendezvous() = default; +template <typename IN, typename OUT, bool external_id> +Rendezvous<IN, OUT, external_id>::~Rendezvous() = default; -template <typename IN, typename OUT> -OUT -Rendezvous<IN, OUT>::rendezvous(IN input) +template <typename IN, typename OUT, bool external_id> +template <bool ext_id> +typename std::enable_if<!ext_id,OUT>::type +Rendezvous<IN, OUT, external_id>::rendezvous(IN input) { - OUT ret = OUT(); + OUT ret{}; + static_assert(ext_id == external_id); if (_size == 1) { - _in[0] = &input; - _out[0] = &ret; - mingle(); + meet_self(input, ret); } else { std::unique_lock guard(_lock); - size_t me = _next++; - _in[me] = &input; - _out[me] = &ret; - if (_next == _size) { - mingle(); - _next = 0; - ++_gen; - _cond.notify_all(); - } else { - size_t oldgen = _gen; - while (oldgen == _gen) { - _cond.wait(guard); - } - } + meet_others(input, ret, _next, std::move(guard)); + } + return ret; +} + +template <typename IN, typename OUT, bool external_id> +template <bool ext_id> +typename std::enable_if<ext_id,OUT>::type +Rendezvous<IN, OUT, external_id>::rendezvous(IN input, size_t my_id) +{ + OUT ret{}; + assert(my_id < _size); + static_assert(ext_id == external_id); + if (_size == 1) { + meet_self(input, ret); + } else { + meet_others(input, ret, my_id, std::unique_lock(_lock)); } return ret; } |