diff options
90 files changed, 2448 insertions, 644 deletions
diff --git a/config-model/pom.xml b/config-model/pom.xml index 95e79fd09fb..c0751431d03 100644 --- a/config-model/pom.xml +++ b/config-model/pom.xml @@ -46,6 +46,11 @@ <scope>test</scope> </dependency> <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + <version>${protobuf.version}</version> + </dependency> + <dependency> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> <scope>provided</scope> @@ -498,6 +503,10 @@ <updateReleaseInfo>true</updateReleaseInfo> </configuration> </plugin> + <plugin> + <groupId>com.github.os72</groupId> + <artifactId>protoc-jar-maven-plugin</artifactId> + </plugin> </plugins> </build> </project> diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 4011ce43841..b153ff62e7d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; +import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; @@ -158,6 +159,12 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments()))); } + // A reference to an ONNX model? + Optional<TensorType> onnxFeatureType = onnxFeatureType(reference); + if (onnxFeatureType.isPresent()) { + return onnxFeatureType.get(); + } + // A reference to a feature which returns a tensor? 
Optional<TensorType> featureTensorType = tensorFeatureType(reference); if (featureTensorType.isPresent()) { @@ -210,6 +217,26 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return Optional.of(function); } + private Optional<TensorType> onnxFeatureType(Reference reference) { + if ( ! reference.name().equals("onnxModel")) + return Optional.empty(); + + if ( ! featureTypes.containsKey(reference)) { + String configOrFileName = reference.arguments().expressions().get(0).toString(); + + // Look up standardized format as added in RankProfile + String modelConfigName = OnnxModelTransformer.getModelConfigName(reference); + String modelOutput = OnnxModelTransformer.getModelOutput(reference, null); + + reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput); + if ( ! featureTypes.containsKey(reference)) { + throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'"); + } + } + + return Optional.of(featureTypes.get(reference)); + } + /** * There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet. * This returns the type of those features if this is a reference to either of them, or empty otherwise. 
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index c2fb2107604..5e8b8579ee6 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -2,14 +2,22 @@ package com.yahoo.searchdefinition; import com.yahoo.config.FileReference; +import com.yahoo.path.Path; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.AbstractService; import com.yahoo.vespa.model.utils.FileSender; +import onnx.Onnx; -import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.List; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; import java.util.Objects; +import java.util.Optional; +import java.util.Set; /** * A global ONNX model distributed using file distribution, similar to ranking constants. 
@@ -21,16 +29,16 @@ public class OnnxModel { public enum PathType {FILE, URI}; private final String name; + private PathType pathType = PathType.FILE; private String path = null; private String fileReference = ""; - private List<OnnxNameMapping> inputMap = new ArrayList<>(); - private List<OnnxNameMapping> outputMap = new ArrayList<>(); - - public PathType getPathType() { - return pathType; - } + private String defaultOutput = null; + private Map<String, String> inputMap = new HashMap<>(); + private Map<String, String> outputMap = new HashMap<>(); - private PathType pathType = PathType.FILE; + private Map<String, Onnx.TypeProto> inputTypes = new HashMap<>(); + private Map<String, Onnx.TypeProto> outputTypes = new HashMap<>(); + private Map<String, TensorType> vespaTypes = new HashMap<>(); public OnnxModel(String name) { this.name = name; @@ -49,21 +57,52 @@ public class OnnxModel { } public void setUri(String uri) { - Objects.requireNonNull(uri, "uri cannot be null"); - this.path = uri; - this.pathType = PathType.URI; + throw new IllegalArgumentException("URI for ONNX models are not currently supported"); + } + + public PathType getPathType() { + return pathType; + } + + public void setDefaultOutput(String onnxName) { + Objects.requireNonNull(onnxName, "Name cannot be null"); + this.defaultOutput = onnxName; } public void addInputNameMapping(String onnxName, String vespaName) { + addInputNameMapping(onnxName, vespaName, true); + } + + public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - this.inputMap.add(new OnnxNameMapping(onnxName, vespaName)); + if (overwrite || ! 
inputMap.containsKey(onnxName)) { + inputMap.put(onnxName, vespaName); + } } public void addOutputNameMapping(String onnxName, String vespaName) { + addOutputNameMapping(onnxName, vespaName, true); + } + + public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - this.outputMap.add(new OnnxNameMapping(onnxName, vespaName)); + if (overwrite || ! outputMap.containsKey(onnxName)) { + outputMap.put(onnxName, vespaName); + } + } + + public void addInputType(String onnxName, Onnx.TypeProto type) { + Objects.requireNonNull(onnxName, "Onnx name cannot be null"); + Objects.requireNonNull(type, "Tensor type cannot be null"); + inputTypes.put(onnxName, type); + } + + public void addOutputType(String onnxName, Onnx.TypeProto type) { + Objects.requireNonNull(onnxName, "Onnx name cannot be null"); + Objects.requireNonNull(type, "Tensor type cannot be null"); + outputTypes.put(onnxName, type); } /** Initiate sending of this constant to some services over file distribution */ @@ -76,11 +115,16 @@ public class OnnxModel { public String getName() { return name; } public String getFileName() { return path; } + public Path getFilePath() { return Path.fromString(path); } public String getUri() { return path; } public String getFileReference() { return fileReference; } - public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); } - public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); } + public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); } + public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); } + + public String getDefaultOutput() { + return defaultOutput; + } public void validate() { if (path == null || path.isEmpty()) @@ -90,23 +134,151 @@ public class OnnxModel { public String 
toString() { StringBuilder b = new StringBuilder(); b.append("onnx-model '").append(name) - .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) - .append("' with ref '").append(fileReference) - .append("'"); + .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) + .append("' with ref '").append(fileReference) + .append("'"); return b.toString(); } - public static class OnnxNameMapping { - private String onnxName; - private String vespaName; + /** + * Return the tensor type for an ONNX model for the given context. + * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output + * type depends on the input types for the given context (rank profile). + */ + public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) { + Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName); + if (onnxOutputType == null) { + throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'"); + } + if (allDimensionSizesAreKnown(onnxOutputType)) { + return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType)); + } + return getTensorTypeWithUnknownDimensions(onnxOutputType, context); + } + + private static boolean allDimensionSizesAreKnown(Onnx.TypeProto type) { + return type.getTensorType().getShape().getDimList().stream().noneMatch(d -> + (d.hasDimParam() && ! 
d.hasDimValue()) || d.getDimValue() == -1); + } + + private TensorType getTensorTypeWithUnknownDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) { + long unboundSize = 0; + Map<String, Long> symbolicSizes = new HashMap<>(); + + for (String onnxInputName : inputTypes.keySet()) { + Onnx.TypeProto onnxType = inputTypes.get(onnxInputName); + if (allDimensionSizesAreKnown(onnxType)) { + continue; + } + + Optional<TensorType> vespaType = resolveInputType(onnxInputName, context); + if (vespaType.isEmpty()) { + return TensorType.empty; + } + + var onnxDimensions = onnxType.getTensorType().getShape().getDimList(); + var vespaDimensions = vespaType.get().dimensions(); + if (vespaDimensions.size() != onnxDimensions.size()) { + return TensorType.empty; + } + + for (int i = 0; i < vespaDimensions.size(); ++i) { + if (vespaDimensions.get(i).size().isEmpty()) { + continue; + } + Long size = vespaDimensions.get(i).size().get(); + + // Handle dimensions with size -1 - typically batch dimensions + if (onnxDimensions.get(i).getDimValue() == -1) { + if (unboundSize != 0 && unboundSize != size) { + throw new IllegalArgumentException("Found conflicting sizes for unbound dimension " + + "for type '" + onnxOutputType + "' in ONNX model '" + name + "'"); + } + unboundSize = size; + + // Handle dimensions with symbolic names + } else if (onnxDimensions.get(i).hasDimParam()) { + String symbolicName = onnxDimensions.get(i).getDimParam(); + if (symbolicSizes.containsKey(symbolicName) && ! 
symbolicSizes.get(symbolicName).equals(size)) { + throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" + + symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'"); + } + symbolicSizes.put(symbolicName, size); + } + } + } + return typeFrom(onnxOutputType, symbolicSizes, unboundSize); + } + + private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) { + String source = inputMap.get(onnxInputName); + if (source != null) { + // Source is either a simple reference (query/attribute/constant)... + Optional<Reference> reference = Reference.simple(source); + if (reference.isPresent()) { + return Optional.of(context.getType(reference.get())); + } + // ... or a function + ExpressionFunction func = context.getFunction(source); + if (func != null) { + return Optional.of(func.getBody().type(context)); + } + } + return Optional.empty(); // if this context does not contain this input + } + + private static TensorType typeFrom(Onnx.TypeProto type) { + return typeFrom(type, null, 0); + } + + private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes, long unboundSize) { + String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... 
+ Onnx.TensorShapeProto shape = type.getTensorType().getShape(); + TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType())); + for (int i = 0; i < shape.getDimCount(); ++ i) { + String dimensionName = dimensionPrefix + i; + Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); + long onnxDimensionSize = onnxDimension.getDimValue(); + if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) { + onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam()); + } + if (onnxDimensionSize == 0 && symbolicSizes != null) { + // This is for the case where all symbolic dimensions have + // different names, but can be resolved to a single dimension size. + Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values()); + if (unknownSizes.size() == 1) { + onnxDimensionSize = unknownSizes.iterator().next(); + } + } + if (onnxDimensionSize < 0) { + onnxDimensionSize = unboundSize; + } + if (onnxDimensionSize <= 0) { + throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " + + "ONNX type: " + type + " to Vespa tensor type."); + } + builder.indexed(dimensionName, onnxDimensionSize); + } + return builder.build(); + } - private OnnxNameMapping(String onnxName, String vespaName) { - this.onnxName = onnxName; - this.vespaName = vespaName; + private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) { + switch (dataType) { + case FLOAT: return TensorType.Value.FLOAT; + case DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case BOOL: return TensorType.Value.FLOAT; + case INT8: return TensorType.Value.FLOAT; + case INT16: return TensorType.Value.FLOAT; + case INT32: return TensorType.Value.FLOAT; + case INT64: return TensorType.Value.FLOAT; + case UINT8: return TensorType.Value.FLOAT; + case UINT16: return TensorType.Value.FLOAT; + case UINT32: return TensorType.Value.FLOAT; + 
case UINT64: return TensorType.Value.FLOAT; + default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); } - public String getOnnxName() { return onnxName; } - public String getVespaName() { return vespaName; } - public void setVespaName(String vespaName) { this.vespaName = vespaName; } } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index d309f48d6df..96c043bdb34 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -18,6 +18,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.VespaModel; @@ -158,6 +159,10 @@ public class RankProfile implements Cloneable { return search != null ? search.rankingConstants() : model.rankingConstants(); } + private Map<String, OnnxModel> onnxModels() { + return search != null ? 
search.onnxModels().asMap() : Collections.emptyMap(); + } + private Stream<ImmutableSDField> allFields() { if (search == null) return Stream.empty(); if (allFieldsList == null) { @@ -821,6 +826,20 @@ public class RankProfile implements Cloneable { } } + // Add output types for ONNX models + for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) { + String modelName = entry.getKey(); + OnnxModel model = entry.getValue(); + Arguments args = new Arguments(new ReferenceNode(modelName)); + + TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), context); + context.setType(new Reference("onnxModel", args, null), defaultOutputType); + + for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) { + TensorType type = model.getTensorType(mapping.getKey(), context); + context.setType(new Reference("onnxModel", args, mapping.getValue()), type); + } + } return context; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 84442fedc48..22a32c8fd65 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -126,8 +126,8 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); modelBuilder.name(model.getName()); modelBuilder.fileref(model.getFileReference()); - model.getInputMap().forEach(mapper -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(mapper.getOnnxName()).source(mapper.getVespaName()))); - model.getOutputMap().forEach(mapper -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(mapper.getOnnxName()).as(mapper.getVespaName()))); + model.getInputMap().forEach((name, source) -> modelBuilder.input(new 
OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); + model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); builder.model(modelBuilder); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 87eaaf0387a..56a5d539906 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -448,10 +448,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer { Set<String> functionNames = rankProfile.getFunctions().keySet(); if (functionNames.isEmpty()) return; for (OnnxModel onnxModel: rankProfile.getSearch().onnxModels().asMap().values()) { - for (OnnxModel.OnnxNameMapping mapping : onnxModel.getInputMap()) { - String source = mapping.getVespaName(); + for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) { + String source = mapping.getValue(); if (functionNames.contains(source)) { - mapping.setVespaName("rankingExpression(" + source + ")"); + onnxModel.addInputNameMapping(mapping.getKey(), "rankingExpression(" + source + ")"); } } } @@ -462,7 +462,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { Set<ReferenceNode> replacedSummaryFeatures = new HashSet<>(); for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) { ReferenceNode referenceNode = i.next(); - ReferenceNode replacedNode = OnnxModelTransformer.transformFeature(referenceNode, rankProfile.getSearch()); + ReferenceNode replacedNode = (ReferenceNode) OnnxModelTransformer.transformFeature(referenceNode, rankProfile); if (referenceNode != replacedNode) { replacedSummaryFeatures.add(replacedNode); i.remove(); diff --git 
a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index ec517768ea9..d23a8376e7a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -38,7 +38,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans } private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { - if ( ! feature.getName().equals("onnx")) return feature; + if ( ! feature.getName().equals("onnx") && ! feature.getName().equals("onnx_vespa")) return feature; try { FeatureArguments arguments = asFeatureArguments(feature.getArguments()); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java index e1ad003e5bd..69cdae10e47 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java @@ -1,20 +1,36 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.searchdefinition.expressiontransforms; +import com.yahoo.path.Path; import com.yahoo.searchdefinition.ImmutableSearch; import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; +import com.yahoo.vespa.model.ml.ModelName; import java.util.List; /** - * Transforms instances of the onnxModel ranking feature and generates - * ONNX configuration if necessary. + * Transforms ONNX model features of the forms: + * + * onnxModel(config_name) + * onnxModel(config_name).output + * onnxModel("path/to/model") + * onnxModel("path/to/model").output + * onnxModel("path/to/model", "path/to/output") + * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused + * + * To the format expected by the backend: + * + * onnxModel(config_name).output * * @author lesters */ @@ -33,85 +49,92 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { if (context.rankProfile() == null) return feature; if (context.rankProfile().getSearch() == null) return feature; - return transformFeature(feature, context.rankProfile().getSearch()); + return transformFeature(feature, context.rankProfile()); } - public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) { - if (!feature.getName().equals("onnxModel")) return feature; + 
public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) { + ImmutableSearch search = rankProfile.getSearch(); + final String featureName = feature.getName(); + if ( ! featureName.equals("onnxModel")) return feature; Arguments arguments = feature.getArguments(); if (arguments.isEmpty()) - throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " + - "onnx-model config or a ONNX file."); - if (arguments.expressions().size() > 2) - throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments."); - - // Validation that the file actually exists is handled when the file is added to file distribution. - // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator. - - String modelConfigName; - OnnxModel onnxModel; - if (arguments.expressions().get(0) instanceof ReferenceNode) { - modelConfigName = arguments.expressions().get(0).toString(); - onnxModel = search.onnxModels().get(modelConfigName); - if (onnxModel == null) { - throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found"); - } - } else if (arguments.expressions().get(0) instanceof ConstantNode) { + throw new IllegalArgumentException("An " + featureName + " feature must take an argument referring to a " + + "onnx-model config or an ONNX file."); + if (arguments.expressions().size() > 3) + throw new IllegalArgumentException("An " + featureName + " feature can have at most 3 arguments."); + + // Check that the model configuration "onnx-model" exists. If not defined, it should have been added + // by the "OnnxModelConfigGenerator" processor. If it still doesn't exist, it is because we can't find + // the actual ONNX file, which can happen if we are restarting or upgrading an application using an + // ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store. 
+ + String modelConfigName = getModelConfigName(feature.reference()); + OnnxModel onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { String path = asString(arguments.expressions().get(0)); - modelConfigName = asValidIdentifier(path); - onnxModel = search.onnxModels().get(modelConfigName); - if (onnxModel == null) { - onnxModel = new OnnxModel(modelConfigName, path); - search.onnxModels().add(onnxModel); - } - } else { - throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'"); + ModelName modelName = new ModelName(null, Path.fromString(path), true); + ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, path, rankProfile); + FeatureArguments featureArguments = new FeatureArguments(arguments); + return convertedModel.expression(featureArguments, null); } - String output = null; - if (feature.getOutput() != null) { - output = feature.getOutput(); - if ( ! hasOutputMapping(onnxModel, output)) { - onnxModel.addOutputNameMapping(output, output); + String defaultOutput = onnxModel.getOutputMap().get(onnxModel.getDefaultOutput()); + String output = getModelOutput(feature.reference(), defaultOutput); + if (! onnxModel.getOutputMap().containsValue(output)) { + throw new IllegalArgumentException(featureName + " argument '" + output + + "' output not found in model '" + onnxModel.getFileName() + "'"); + } + return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output); + } + + public static String getModelConfigName(Reference reference) { + if (reference.arguments().size() > 0) { + ExpressionNode expr = reference.arguments().expressions().get(0); + if (expr instanceof ReferenceNode) { // refers to onnx-model config + return expr.toString(); } - } else if (arguments.expressions().size() > 1) { - String name = asString(arguments.expressions().get(1)); - output = asValidIdentifier(name); - if ( ! 
hasOutputMapping(onnxModel, output)) { - onnxModel.addOutputNameMapping(name, output); + if (expr instanceof ConstantNode) { // refers to an file path + return asValidIdentifier(expr); } } + return null; + } - // Replace feature with name of config - ExpressionNode argument = new ReferenceNode(modelConfigName); - return new ReferenceNode("onnxModel", List.of(argument), output); - + public static String getModelOutput(Reference reference, String defaultOutput) { + if (reference.output() != null) { + return reference.output(); + } else if (reference.arguments().expressions().size() == 2) { + return asValidIdentifier(reference.arguments().expressions().get(1)); + } else if (reference.arguments().expressions().size() > 2) { + return asValidIdentifier(reference.arguments().expressions().get(2)); + } + return defaultOutput; } - private static boolean hasOutputMapping(OnnxModel onnxModel, String as) { - return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as)); + public static String stripQuotes(String s) { + if (isNotQuoteSign(s.codePointAt(0))) return s; + if (isNotQuoteSign(s.codePointAt(s.length() - 1))) + throw new IllegalArgumentException("argument [" + s + "] is missing end quote"); + return s.substring(1, s.length()-1); } - private static String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); + public static String asValidIdentifier(String str) { + return str.replaceAll("[^\\w\\d\\$@_]", "_"); } - private static String stripQuotes(String s) { - if ( ! isQuoteSign(s.codePointAt(0))) return s; - if ( ! 
isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); - return s.substring(1, s.length()-1); + private static String asValidIdentifier(ExpressionNode node) { + return asValidIdentifier(asString(node)); } - private static boolean isQuoteSign(int c) { - return c == '\'' || c == '"'; + private static boolean isNotQuoteSign(int c) { + return c != '\'' && c != '"'; } - private static String asValidIdentifier(String str) { - return str.replaceAll("[^\\w\\d\\$@_]", "_"); + public static String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java new file mode 100644 index 00000000000..afba88c135d --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java @@ -0,0 +1,97 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +package com.yahoo.searchdefinition.processing; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.application.api.DeployLogger; +import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.vespa.model.container.search.QueryProfiles; + +import java.util.Map; + +/** + * Processes ONNX ranking features of the form: + * + * onnx("files/model.onnx", "path/to/output:1") + * + * And generates an "onnx-model" configuration as if it was defined in the schema: + * + * onnx-model files_model_onnx { + * file: "files/model.onnx" + * } + * + * Inputs and outputs are resolved in OnnxModelTypeResolver, which must be + * processed after this. 
+ * + * @author lesters + */ +public class OnnxModelConfigGenerator extends Processor { + + public OnnxModelConfigGenerator(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) { + super(search, deployLogger, rankProfileRegistry, queryProfiles); + } + + @Override + public void process(boolean validate, boolean documentsOnly) { + if (documentsOnly) return; + for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) { + if (profile.getFirstPhaseRanking() != null) { + process(profile.getFirstPhaseRanking().getRoot()); + } + if (profile.getSecondPhaseRanking() != null) { + process(profile.getSecondPhaseRanking().getRoot()); + } + for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) { + process(function.getValue().function().getBody().getRoot()); + } + for (ReferenceNode feature : profile.getSummaryFeatures()) { + process(feature); + } + } + } + + private void process(ExpressionNode node) { + if (node instanceof ReferenceNode) { + process((ReferenceNode)node); + } else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode) node).children()) { + process(child); + } + } + } + + private void process(ReferenceNode feature) { + if (feature.getName().equals("onnxModel") || feature.getName().equals("onnx")) { + if (feature.getArguments().size() > 0) { + if (feature.getArguments().expressions().get(0) instanceof ConstantNode) { + ConstantNode node = (ConstantNode) feature.getArguments().expressions().get(0); + String path = OnnxModelTransformer.stripQuotes(node.sourceString()); + String modelConfigName = OnnxModelTransformer.asValidIdentifier(path); + + // Only add the configuration if the model can actually be found. + if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) { + path = ApplicationPackage.MODELS_DIR.append(path).toString(); + if ( ! 
OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) { + return; + } + } + + OnnxModel onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + onnxModel = new OnnxModel(modelConfigName, path); + search.onnxModels().add(onnxModel); + } + } + } + } + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java new file mode 100644 index 00000000000..bead2e7e7c9 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java @@ -0,0 +1,154 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchdefinition.processing; + +import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.component.Version; +import com.yahoo.config.FileReference; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.application.api.DeployLogger; +import com.yahoo.config.application.api.FileRegistry; +import com.yahoo.path.Path; +import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; +import com.yahoo.vespa.defaults.Defaults; +import com.yahoo.vespa.model.container.search.QueryProfiles; +import onnx.Onnx; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Paths; +import java.util.Map; +import java.util.Optional; + +/** + * Processes every "onnx-model" element in the schema. 
Parses the model file, + * adds missing input and output mappings (assigning default names), and + * adds tensor types to all model inputs and outputs. + * + * Must be processed before RankingExpressingTypeResolver. + * + * @author lesters + */ +public class OnnxModelTypeResolver extends Processor { + + public OnnxModelTypeResolver(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) { + super(search, deployLogger, rankProfileRegistry, queryProfiles); + } + + @Override + public void process(boolean validate, boolean documentsOnly) { + if (documentsOnly) return; + + for (Map.Entry<String, OnnxModel> entry : search.onnxModels().asMap().entrySet()) { + OnnxModel modelConfig = entry.getValue(); + try (InputStream inputStream = openModelFile(modelConfig.getFilePath())) { + Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); + + // Model inputs - if not defined, assumes a function is provided with a valid name + for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) { + String onnxInputName = valueInfo.getName(); + String vespaInputName = OnnxModelTransformer.asValidIdentifier(onnxInputName); + modelConfig.addInputNameMapping(onnxInputName, vespaInputName, false); + modelConfig.addInputType(onnxInputName, valueInfo.getType()); + } + + // Model outputs + for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) { + String onnxOutputName = valueInfo.getName(); + String vespaOutputName = OnnxModelTransformer.asValidIdentifier(onnxOutputName); + modelConfig.addOutputNameMapping(onnxOutputName, vespaOutputName, false); + modelConfig.addOutputType(onnxOutputName, valueInfo.getType()); + } + + // Set the first output as default + if ( ! 
model.getGraph().getOutputList().isEmpty()) { + modelConfig.setDefaultOutput(model.getGraph().getOutput(0).getName()); + } + + } catch (IOException e) { + throw new IllegalArgumentException("Unable to parse ONNX model", e); + } + } + } + + static boolean modelFileExists(String path, ApplicationPackage app) { + Path pathInApplicationPackage = Path.fromString(path); + if (getFile(pathInApplicationPackage, app).exists()) { + return true; + } + if (getFileReference(pathInApplicationPackage, app).isPresent()) { + return true; + } + return false; + } + + private InputStream openModelFile(Path path) throws FileNotFoundException { + ApplicationFile file; + Optional<FileReference> reference; + Path modelsPath = ApplicationPackage.MODELS_DIR.append(path); + + if ((file = getFile(path)).exists()) { + return file.createInputStream(); + } + if ((file = getFile(modelsPath)).exists()) { + return file.createInputStream(); + } + if ((reference = getFileReference(path)).isPresent()) { + return openFromFileRepository(path, reference.get()); + } + if ((reference = getFileReference(modelsPath)).isPresent()) { + return openFromFileRepository(modelsPath, reference.get()); + } + + throw new IllegalArgumentException("Unable to find ONNX model file \"" + path + "\" " + + "in application package or file repository."); + } + + private ApplicationFile getFile(Path path) { + return getFile(path, search.applicationPackage()); + } + + private static ApplicationFile getFile(Path path, ApplicationPackage app) { + return app.getFile(path); + } + + private static InputStream openFromFileRepository(Path path, FileReference reference) throws FileNotFoundException { + return new FileInputStream(new File(getFileRepositoryPath(path, reference.value()))); + } + + public static String getFileRepositoryPath(Path path, String fileReference) { + ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults + String fileRefDir = 
Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir()); + return Paths.get(fileRefDir, fileReference, path.getName()).toString(); + } + + private Optional<FileReference> getFileReference(Path path) { + return getFileReference(path, search.applicationPackage()); + } + + private static Optional<FileReference> getFileReference(Path path, ApplicationPackage app) { + Optional<FileRegistry> fileRegistry = getLatestFileRegistry(app); + if (fileRegistry.isPresent()) { + for (FileRegistry.Entry file : fileRegistry.get().export()) { + if (file.relativePath.equals(path.toString())) { + return Optional.of(file.reference); + } + } + } + return Optional.empty(); + } + + private static Optional<FileRegistry> getLatestFileRegistry(ApplicationPackage app) { + if (app == null) return Optional.empty(); + Optional<Version> latest = app.getFileRegistries().keySet().stream().max(Version::compareTo); + return latest.isEmpty() ? Optional.empty() : Optional.of(app.getFileRegistries().get(latest.get())); + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java index e8594c2a87f..1a3ef9e54b4 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java @@ -74,6 +74,8 @@ public class Processing { ReferenceFieldsProcessor::new, FastAccessValidator::new, ReservedFunctionNames::new, + OnnxModelConfigGenerator::new, + OnnxModelTypeResolver::new, RankingExpressionTypeResolver::new, // These should be last: IndexingValidation::new, diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java index c6c7969e466..d5c5183b01f 100644 --- 
a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java @@ -1,20 +1,19 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.application.validation; -import com.yahoo.cloud.config.ConfigserverConfig; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.IOUtils; import com.yahoo.log.InvalidLogFormatException; import java.util.logging.Level; import com.yahoo.log.LogMessage; import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.processing.OnnxModelTypeResolver; import com.yahoo.yolean.Exceptions; import com.yahoo.system.ProcessExecuter; import com.yahoo.text.StringUtilities; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.collections.Pair; import com.yahoo.config.ConfigInstance; -import com.yahoo.vespa.defaults.Defaults; import com.yahoo.vespa.config.search.ImportedFieldsConfig; import com.yahoo.vespa.config.search.IndexschemaConfig; import com.yahoo.vespa.config.search.RankProfilesConfig; @@ -31,7 +30,6 @@ import com.yahoo.vespa.model.search.SearchCluster; import java.io.File; import java.io.IOException; import java.nio.file.Files; -import java.nio.file.Paths; import java.time.Duration; import java.time.Instant; import java.util.logging.Logger; @@ -152,12 +150,9 @@ public class RankSetupValidator extends Validator { // Assist verify-ranksetup in finding the actual ONNX model files Map<String, OnnxModel> models = db.getDerivedConfiguration().getSearch().onnxModels().asMap(); if (models.values().size() > 0) { - ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults - String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir()); List<String> config = new ArrayList<>(models.values().size() * 2); for 
(OnnxModel model : models.values()) { - String modelFilename = Paths.get(model.getFileName()).getFileName().toString(); - String modelPath = Paths.get(fileRefDir, model.getFileReference(), modelFilename).toString(); + String modelPath = OnnxModelTypeResolver.getFileRepositoryPath(model.getFilePath(), model.getFileReference()); config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference())); config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath)); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 943fcbf6c1d..5ee6ed02e61 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -150,7 +150,7 @@ public class ConvertedModel { */ public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) { ExpressionFunction expression = selectExpression(arguments); - if (sourceModel.isPresent()) // we should verify + if (sourceModel.isPresent() && context != null) // we should verify verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles()); return expression.getBody().getRoot(); } diff --git a/config-model/src/main/protobuf/onnx.proto b/config-model/src/main/protobuf/onnx.proto new file mode 100644 index 00000000000..dc6542867e0 --- /dev/null +++ b/config-model/src/main/protobuf/onnx.proto @@ -0,0 +1,464 @@ +// +// WARNING: This file is automatically generated! Please edit onnx.in.proto. +// + + +// Copyright (c) Facebook Inc. and Microsoft Corporation. +// Licensed under the MIT license. + +syntax = "proto2"; + +package onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. 
+// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. 
We should use version as + // xx(major) - xx(minor) - xxxx(bugfix) + // and we are starting with 0x00000001 (0.0.1), which was the + // version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x00000001; + + // IR_VERSION 0.0.2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x00000002; + + // IR VERSION 0.0.3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION = 0x00000003; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + } + + // The name field MUST be present for this version of the IR. + optional string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + optional string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + optional string doc_string = 13; + + // The type field MUST be present for this version of the IR. 
+ // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accomodate proto3 implementations. + optional AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + optional float f = 2; // float + optional int64 i = 3; // int + optional bytes s = 4; // UTF-8 string + optional TensorProto t = 5; // tensor value + optional GraphProto g = 6; // graph + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + optional string name = 1; // namespace Value + // This field MUST be present in this version of the IR. + optional TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + optional string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. 
+message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + optional string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + optional string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + optional string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + optional string doc_string = 6; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + optional int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_version = 3; + + // Domain name of the model. 
+ // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + optional string domain = 4; + + // The version of the graph encoded. See Version enum below. + optional int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + optional string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + optional GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + optional string key = 1; + optional string value= 2; +}; + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + optional string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // also appears in the input list. + repeated TensorProto initializer = 5; + + // A human-readable documentation for this graph. Markdown is allowed. + optional string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. 
It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // Advanced types + FLOAT16 = 10; + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + optional DataType data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + optional int64 begin = 1; + optional int64 end = 2; + } + optional Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. 
(e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + optional string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + optional string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). 
+ // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + optional bytes raw_data = 9; + + // For double + // Complex64 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + optional string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// A set of pre-defined constants to be used as values for +// the standard denotation field in TensorShapeProto.Dimension +// for semantic description of the tensor dimension. +message DenotationConstProto { + // Describe a batch number dimension. + optional string DATA_BATCH = 1 [default = "DATA_BATCH"]; + // Describe a channel dimension. + optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"]; + // Describe a time dimension. 
+ optional string DATA_TIME = 3 [default = "DATA_TIME"]; + // Describe a feature dimension. This is typically a feature + // dimension in RNN and/or spatial dimension in CNN. + optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"]; + // Describe a filter in-channel dimension. This is the dimension + // that is identical (in size) to the channel dimension of the input + // image feature maps. + optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"]; + // Describe a filter out channel dimension. This is the dimension + // that is identical (int size) to the channel dimension of the output + // image feature maps. + optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"]; + // Describe a filter spatial dimension. + optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"]; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST be present for this version of the IR. + optional TensorProto.DataType elem_type = 1; + optional TensorShapeProto shape = 2; + } + + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + } +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + optional string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + optional int64 version = 2; +}
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd index e9575af6010..cc73f2daff5 100644 --- a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd +++ b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd @@ -18,7 +18,7 @@ search test { } function mnist_softmax_onnx() { - expression: onnx("mnist_softmax") + expression: onnx_vespa("mnist_softmax") } function my_xgboost() { diff --git a/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py b/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py new file mode 100755 index 00000000000..55df3a557e9 --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py @@ -0,0 +1,12 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import onnx +from onnx import helper, TensorProto + +INPUT = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "sequence"]) +OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "sequence"]) + +nodes = [helper.make_node('Identity', ['input'], ['output'])] +graph_def = helper.make_graph( nodes, 'simple_scoring', [INPUT], [OUTPUT]) +model_def = helper.make_model(graph_def, producer_name='create_dynamic_model.py') +onnx.save(model_def, 'dynamic_model.onnx') diff --git a/config-model/src/test/integration/onnx-model/files/create_model.py b/config-model/src/test/integration/onnx-model/files/create_model.py new file mode 100755 index 00000000000..10ff92c2eda --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/create_model.py @@ -0,0 +1,37 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +import onnx +from onnx import helper, TensorProto + +INPUT_1 = helper.make_tensor_value_info('first_input', TensorProto.FLOAT, [2]) +INPUT_2 = helper.make_tensor_value_info('second/input:0', TensorProto.FLOAT, [2]) +INPUT_3 = helper.make_tensor_value_info('third_input', TensorProto.FLOAT, [2]) +OUTPUT_1 = helper.make_tensor_value_info('path/to/output:0', TensorProto.FLOAT, [2]) +OUTPUT_2 = helper.make_tensor_value_info('path/to/output:1', TensorProto.FLOAT, [2]) +OUTPUT_3 = helper.make_tensor_value_info('path/to/output:2', TensorProto.FLOAT, [2]) + +nodes = [ + helper.make_node( + 'Add', + ['first_input', 'second/input:0'], + ['path/to/output:0'], + ), + helper.make_node( + 'Add', + ['third_input', 'second/input:0'], + ['path/to/output:1'] + ), + helper.make_node( + 'Add', + ['path/to/output:0', 'path/to/output:1'], + ['path/to/output:2'] + ), +] +graph_def = helper.make_graph( + nodes, + 'simple_scoring', + [INPUT_1, INPUT_2, INPUT_3], + [OUTPUT_1, OUTPUT_2, OUTPUT_3] +) +model_def = helper.make_model(graph_def, producer_name='create_model.py') +onnx.save(model_def, 'model.onnx') diff --git a/config-model/src/test/integration/onnx-model/files/create_unbound_model.py b/config-model/src/test/integration/onnx-model/files/create_unbound_model.py new file mode 100755 index 00000000000..abf733ea43f --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/create_unbound_model.py @@ -0,0 +1,12 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +import onnx +from onnx import helper, TensorProto + +INPUT = helper.make_tensor_value_info('input', TensorProto.FLOAT, [-1, 2]) +OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [-1, 2]) + +nodes = [helper.make_node('Identity', ['input'], ['output'])] +graph_def = helper.make_graph( nodes, 'simple_scoring', [INPUT], [OUTPUT]) +model_def = helper.make_model(graph_def, producer_name='create_unbound_model.py') +onnx.save(model_def, 'unbound_model.onnx') diff --git a/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx b/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx new file mode 100644 index 00000000000..6bbdad2d76e --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx @@ -0,0 +1,13 @@ +create_dynamic_model.py:x + +inputoutput"Identitysimple_scoringZ$ +input + +batch + +sequenceb% +output + +batch + +sequenceB
\ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/files/model.onnx b/config-model/src/test/integration/onnx-model/files/model.onnx new file mode 100644 index 00000000000..f3898205c6a --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/model.onnx @@ -0,0 +1,34 @@ +create_model.py:í +4 +first_input +second/input:0path/to/output:0"Add +4 +third_input +second/input:0path/to/output:1"Add +; +path/to/output:0 +path/to/output:1path/to/output:2"Addsimple_scoringZ +first_input + + +Z +second/input:0 + + +Z +third_input + + +b +path/to/output:0 + + +b +path/to/output:1 + + +b +path/to/output:2 + + +B
\ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/files/summary_model.onnx b/config-model/src/test/integration/onnx-model/files/summary_model.onnx new file mode 100644 index 00000000000..f3898205c6a --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/summary_model.onnx @@ -0,0 +1,34 @@ +create_model.py:í +4 +first_input +second/input:0path/to/output:0"Add +4 +third_input +second/input:0path/to/output:1"Add +; +path/to/output:0 +path/to/output:1path/to/output:2"Addsimple_scoringZ +first_input + + +Z +second/input:0 + + +Z +third_input + + +b +path/to/output:0 + + +b +path/to/output:1 + + +b +path/to/output:2 + + +B
\ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/files/unbound_model.onnx b/config-model/src/test/integration/onnx-model/files/unbound_model.onnx new file mode 100644 index 00000000000..155b3125256 --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/unbound_model.onnx @@ -0,0 +1,11 @@ +create_unbound_model.py:p + +inputoutput"Identitysimple_scoringZ +input + +ÿÿÿÿÿÿÿÿÿ +b! +output + +ÿÿÿÿÿÿÿÿÿ +B
\ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd index 0f0fa694e6f..a87222e77ee 100644 --- a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd +++ b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd @@ -14,7 +14,7 @@ search test { } onnx-model my_model { - file: files/ranking_model.onnx + file: files/model.onnx input first_input: attribute(document_field) input "second/input:0": constant(my_constant) input "third_input": my_function @@ -22,19 +22,31 @@ search test { } onnx-model another_model { - file: files/ranking_model.onnx + file: files/model.onnx input first_input: attribute(document_field) input "second/input:0": constant(my_constant) input "third_input": another_function output "path/to/output:2": out } + onnx-model dynamic_model { + file: files/dynamic_model.onnx + input input: my_function + output output: my_output + } + + onnx-model unbound_model { + file: files/unbound_model.onnx + input input: my_function + output output: my_output + } + rank-profile test_model_config { function my_function() { expression: tensor(d0[2])(1) } first-phase { - expression: onnxModel(my_model).out + expression: onnxModel(my_model).out{d0:1} } } @@ -49,7 +61,7 @@ search test { expression: my_function() } first-phase { - expression: onnxModel("files/ranking_model.onnx", "path/to/output:1") + expression: onnxModel("files/model.onnx", "path/to/output:1"){d0:1} } } @@ -62,9 +74,39 @@ search test { } summary-features { onnxModel(another_model).out - onnxModel("files/ranking_model.onnx", "path/to/output:2") + onnxModel("files/summary_model.onnx", "path/to/output:2") } + } + rank-profile test_dynamic_model { + function my_function() { + expression: tensor(d0[1],d1[2])(d1) + } + first-phase { + expression: onnxModel(dynamic_model){d0:0,d1:1} + } } + rank-profile test_dynamic_model_2 { + function 
my_function_2() { + expression: tensor(d0[1],d1[3])(d1) + } + function my_function() { + expression: my_function_2() + } + first-phase { + expression: onnxModel(dynamic_model){d0:0,d1:2} + } + } + + rank-profile test_unbound_model { + function my_function() { + expression: tensor(d0[1],d1[2])(d1) + } + first-phase { + expression: onnxModel(unbound_model){d0:0,d1:1} + } + } + + } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java index d9b0c70dfdd..4eb8681c374 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java @@ -25,43 +25,68 @@ public class RankingExpressionWithOnnxModelTestCase { OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder(); ((OnnxModelsConfig.Producer) db).getConfig(builder); OnnxModelsConfig config = new OnnxModelsConfig(builder); - assertEquals(3, config.model().size()); + assertEquals(6, config.model().size()); - assertEquals("my_model", config.model(1).name()); + assertEquals("my_model", config.model(0).name()); + assertEquals(3, config.model(0).input().size()); + assertEquals("second/input:0", config.model(0).input(0).name()); + assertEquals("constant(my_constant)", config.model(0).input(0).source()); + assertEquals("first_input", config.model(0).input(1).name()); + assertEquals("attribute(document_field)", config.model(0).input(1).source()); + assertEquals("third_input", config.model(0).input(2).name()); + assertEquals("rankingExpression(my_function)", config.model(0).input(2).source()); + assertEquals(3, config.model(0).output().size()); + assertEquals("path/to/output:0", config.model(0).output(0).name()); + assertEquals("out", config.model(0).output(0).as()); + 
assertEquals("path/to/output:1", config.model(0).output(1).name()); + assertEquals("path_to_output_1", config.model(0).output(1).as()); + assertEquals("path/to/output:2", config.model(0).output(2).name()); + assertEquals("path_to_output_2", config.model(0).output(2).as()); + + assertEquals("files_model_onnx", config.model(1).name()); assertEquals(3, config.model(1).input().size()); - assertEquals("first_input", config.model(1).input(0).name()); - assertEquals("attribute(document_field)", config.model(1).input(0).source()); - assertEquals("second/input:0", config.model(1).input(1).name()); - assertEquals("constant(my_constant)", config.model(1).input(1).source()); - assertEquals("third_input", config.model(1).input(2).name()); - assertEquals("rankingExpression(my_function)", config.model(1).input(2).source()); - assertEquals(1, config.model(1).output().size()); + assertEquals(3, config.model(1).output().size()); assertEquals("path/to/output:0", config.model(1).output(0).name()); - assertEquals("out", config.model(1).output(0).as()); - - assertEquals("files_ranking_model_onnx", config.model(0).name()); - assertEquals(0, config.model(0).input().size()); - assertEquals(2, config.model(0).output().size()); - assertEquals("path/to/output:1", config.model(0).output(0).name()); - assertEquals("path_to_output_1", config.model(0).output(0).as()); - assertEquals("path/to/output:2", config.model(0).output(1).name()); - assertEquals("path_to_output_2", config.model(0).output(1).as()); + assertEquals("path_to_output_0", config.model(1).output(0).as()); + assertEquals("path/to/output:1", config.model(1).output(1).name()); + assertEquals("path_to_output_1", config.model(1).output(1).as()); + assertEquals("path/to/output:2", config.model(1).output(2).name()); + assertEquals("path_to_output_2", config.model(1).output(2).as()); + assertEquals("files_model_onnx", config.model(1).name()); assertEquals("another_model", config.model(2).name()); assertEquals("third_input", 
config.model(2).input(2).name()); assertEquals("rankingExpression(another_function)", config.model(2).input(2).source()); + + assertEquals("files_summary_model_onnx", config.model(3).name()); + assertEquals(3, config.model(3).input().size()); + assertEquals(3, config.model(3).output().size()); + + assertEquals("dynamic_model", config.model(5).name()); + assertEquals(1, config.model(5).input().size()); + assertEquals(1, config.model(5).output().size()); + assertEquals("rankingExpression(my_function)", config.model(5).input(0).source()); + + assertEquals("unbound_model", config.model(4).name()); + assertEquals(1, config.model(4).input().size()); + assertEquals(1, config.model(4).output().size()); + assertEquals("rankingExpression(my_function)", config.model(4).input(0).source()); + + } private void assertTransformedFeature(DocumentDatabase db) { RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder(); ((RankProfilesConfig.Producer) db).getConfig(builder); RankProfilesConfig config = new RankProfilesConfig(builder); - assertEquals(5, config.rankprofile().size()); + assertEquals(8, config.rankprofile().size()); assertEquals("test_model_config", config.rankprofile(2).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name()); assertEquals("vespa.rank.firstphase", config.rankprofile(2).fef().property(2).name()); - assertEquals("onnxModel(my_model).out", config.rankprofile(2).fef().property(2).value()); + assertEquals("rankingExpression(firstphase)", config.rankprofile(2).fef().property(2).value()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(2).fef().property(3).name()); + assertEquals("onnxModel(my_model).out{d0:1}", config.rankprofile(2).fef().property(3).value()); assertEquals("test_generated_model_config", config.rankprofile(3).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(3).fef().property(0).name()); @@ -69,16 
+94,34 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("rankingExpression(second_input).rankingScript", config.rankprofile(3).fef().property(4).name()); assertEquals("rankingExpression(third_input).rankingScript", config.rankprofile(3).fef().property(6).name()); assertEquals("vespa.rank.firstphase", config.rankprofile(3).fef().property(8).name()); - assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_1", config.rankprofile(3).fef().property(8).value()); + assertEquals("rankingExpression(firstphase)", config.rankprofile(3).fef().property(8).value()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(3).fef().property(9).name()); + assertEquals("onnxModel(files_model_onnx).path_to_output_1{d0:1}", config.rankprofile(3).fef().property(9).value()); assertEquals("test_summary_features", config.rankprofile(4).name()); assertEquals("rankingExpression(another_function).rankingScript", config.rankprofile(4).fef().property(0).name()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(4).fef().property(3).name()); assertEquals("1", config.rankprofile(4).fef().property(3).value()); assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(4).name()); - assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(4).value()); + assertEquals("onnxModel(files_summary_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(4).value()); assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(5).name()); - assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(5).value()); + assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(5).value()); + + assertEquals("test_dynamic_model", config.rankprofile(5).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(5).fef().property(0).name()); + 
assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(5).fef().property(3).name()); + assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:1}", config.rankprofile(5).fef().property(3).value()); + + assertEquals("test_dynamic_model_2", config.rankprofile(6).name()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(6).fef().property(5).name()); + assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value()); + + assertEquals("test_unbound_model", config.rankprofile(7).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(0).name()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(7).fef().property(3).name()); + assertEquals("onnxModel(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(7).fef().property(3).value()); + + } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 6bf69907609..40bf970a313 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -70,7 +70,7 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", - "onnx('mnist_softmax.onnx')", + "onnx_vespa('mnist_softmax.onnx')", "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -87,7 +87,7 @@ public class RankingExpressionWithOnnxTestCase { queryProfile, queryProfileType); RankProfileSearchFixture search = fixtureWith("query(mytensor)", - 
"onnx('mnist_softmax.onnx')", + "onnx_vespa('mnist_softmax.onnx')", null, null, "Placeholder", @@ -99,7 +99,7 @@ public class RankingExpressionWithOnnxTestCase { public void testOnnxReferenceWithDocumentFeature() { StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("attribute(mytensor)", - "onnx('mnist_softmax.onnx')", + "onnx_vespa('mnist_softmax.onnx')", null, "field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }", "Placeholder", @@ -117,7 +117,7 @@ public class RankingExpressionWithOnnxTestCase { "</query-profile-type>"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, queryProfileType); RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)", - "onnx('mnist_softmax.onnx')", + "onnx_vespa('mnist_softmax.onnx')", "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }", "field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }", "Placeholder", @@ -129,21 +129,21 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testNestedOnnxReference() { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "5 + sum(onnx('mnist_softmax.onnx'))"); + "5 + sum(onnx_vespa('mnist_softmax.onnx'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); } @Test public void testOnnxReferenceWithSpecifiedOutput() { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 'layer_add')"); + "onnx_vespa('mnist_softmax.onnx', 'layer_add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test public void testOnnxReferenceWithSpecifiedOutputAndSignature() { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 
'default.layer_add')"); + "onnx_vespa('mnist_softmax.onnx', 'default.layer_add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @@ -155,7 +155,7 @@ public class RankingExpressionWithOnnxTestCase { new QueryProfileRegistry(), " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: onnx('mnist_softmax.onnx')" + + " expression: onnx_vespa('mnist_softmax.onnx')" + " }\n" + " }"); search.compileRankProfile("my_profile", applicationDir.append("models")); @@ -164,7 +164,7 @@ public class RankingExpressionWithOnnxTestCase { } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx('mnist_softmax.onnx'): " + + "onnx_vespa('mnist_softmax.onnx'): " + "Model refers input 'Placeholder' of type tensor<float>(d0[1],d1[784]) but this function is " + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); @@ -175,13 +175,13 @@ public class RankingExpressionWithOnnxTestCase { public void testOnnxReferenceWithWrongFunctionType() { try { RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d5[10])(0.0)", - "onnx('mnist_softmax.onnx')"); + "onnx_vespa('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx('mnist_softmax.onnx'): " + + "onnx_vespa('mnist_softmax.onnx'): " + "Model refers input 'Placeholder'. 
The required type of this is tensor<float>(d0[1],d1[784]), " + "but this function returns tensor(d0[1],d5[10])", Exceptions.toMessageString(expected)); @@ -192,13 +192,13 @@ public class RankingExpressionWithOnnxTestCase { public void testOnnxReferenceSpecifyingNonExistingOutput() { try { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 'y')"); + "onnx_vespa('mnist_softmax.onnx', 'y')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx('mnist_softmax.onnx','y'): " + + "onnx_vespa('mnist_softmax.onnx','y'): " + "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add", Exceptions.toMessageString(expected)); } @@ -207,7 +207,7 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testImportingFromStoredExpressions() throws IOException { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "onnx('mnist_softmax.onnx')"); + "onnx_vespa('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); // At this point the expression is stored - copy application to another location which do not have a models dir @@ -218,7 +218,7 @@ public class RankingExpressionWithOnnxTestCase { storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); RankProfileSearchFixture searchFromStored = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", - "onnx('mnist_softmax.onnx')", + "onnx_vespa('mnist_softmax.onnx')", null, null, "Placeholder", @@ -243,7 +243,7 @@ public class RankingExpressionWithOnnxTestCase { " expression: tensor<float>(d1[10],d2[784])(0.0)\n" + " }\n" + " first-phase {\n" + - " 
expression: onnx('mnist_softmax.onnx')" + + " expression: onnx_vespa('mnist_softmax.onnx')" + " }\n" + " }" + " rank-profile my_profile_child inherits my_profile {\n" + @@ -288,7 +288,7 @@ public class RankingExpressionWithOnnxTestCase { " expression: tensor<float>(d0[3])(0.0)\n" + " }\n" + " first-phase {\n" + - " expression: onnx('" + name + ".onnx')" + + " expression: onnx_vespa('" + name + ".onnx')" + " }\n" + " }"; final String functionName = "imported_ml_function_" + name + "_exp_output"; @@ -310,7 +310,7 @@ public class RankingExpressionWithOnnxTestCase { " expression: tensor<float>(d0[3])(0.0)\n" + " }\n" + " first-phase {\n" + - " expression: onnx('" + name + ".onnx')" + + " expression: onnx_vespa('" + name + ".onnx')" + " }\n" + " }" + " rank-profile my_profile_child inherits my_profile {\n" + diff --git a/configdefinitions/src/vespa/stor-filestor.def b/configdefinitions/src/vespa/stor-filestor.def index bf1b4294b5b..1cec77832a7 100644 --- a/configdefinitions/src/vespa/stor-filestor.def +++ b/configdefinitions/src/vespa/stor-filestor.def @@ -63,3 +63,11 @@ enable_merge_local_node_choose_docs_optimalization bool default=true restart ## if splitting is expensive, but listing document identifiers is fairly cheap. ## This is true for memfile persistence layer, but not for vespa search. enable_multibit_split_optimalization bool default=true restart + +## Whether or not to use async message handling when scheduling storage messages from FileStorManager. +## +## When turned on, the calling thread (e.g. FNET network thread when using Storage API RPC) +## gets the next async message to handle (if any) as part of scheduling a storage message. +## This async message is then handled by the calling thread immediately, +## instead of going via a persistence thread. 
+use_async_message_handling_on_schedule bool default=false restart diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java index a522e26a46d..7b4d82a9f53 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java @@ -18,6 +18,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.organization.IssueHandl import com.yahoo.vespa.hosted.controller.api.integration.organization.Mailer; import com.yahoo.vespa.hosted.controller.api.integration.organization.OwnershipIssues; import com.yahoo.vespa.hosted.controller.api.integration.organization.SystemMonitor; +import com.yahoo.vespa.hosted.controller.api.integration.repair.HostRepairClient; import com.yahoo.vespa.hosted.controller.api.integration.resource.CostReportConsumer; import com.yahoo.vespa.hosted.controller.api.integration.resource.MeteringClient; import com.yahoo.vespa.hosted.controller.api.integration.routing.GlobalRoutingService; @@ -79,4 +80,5 @@ public interface ServiceRegistry { BillingController billingController(); + HostRepairClient hostRepairClient(); } diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/ZmsClientMock.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/ZmsClientMock.java index ec5d62569f6..942f0f35f58 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/ZmsClientMock.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/ZmsClientMock.java @@ -136,6 +136,17 @@ public class ZmsClientMock implements ZmsClient { } @Override + public void addPolicyRule(AthenzDomain athenzDomain, String athenzPolicy, 
String action, AthenzResourceName resourceName, AthenzRole athenzRole) { + + } + + @Override + public boolean deletePolicyRule(AthenzDomain athenzDomain, String athenzPolicy, String action, AthenzResourceName resourceName, AthenzRole athenzRole) { + return false; + } + + + @Override public void close() {} private static AthenzDomain getTenantDomain(AthenzResourceName resource) { diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/Node.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/Node.java index 07e411cd5cd..b57b2dbc496 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/Node.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/Node.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.api.integration.configserver; +import com.fasterxml.jackson.databind.JsonNode; import com.yahoo.component.Version; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.DockerImage; @@ -10,6 +11,8 @@ import com.yahoo.config.provision.NodeType; import com.yahoo.config.provision.TenantName; import java.time.Instant; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -48,13 +51,14 @@ public class Node { private final boolean wantToRetire; private final boolean wantToDeprovision; private final Optional<TenantName> reservedTo; + private final Map<String, JsonNode> reports; public Node(HostName hostname, Optional<HostName> parentHostname, State state, NodeType type, NodeResources resources, Optional<ApplicationId> owner, Version currentVersion, Version wantedVersion, Version currentOsVersion, Version wantedOsVersion, Optional<Instant> currentFirmwareCheck, Optional<Instant> 
wantedFirmwareCheck, ServiceState serviceState, Optional<Instant> suspendedSince, long restartGeneration, long wantedRestartGeneration, long rebootGeneration, long wantedRebootGeneration, int cost, String flavor, String clusterId, ClusterType clusterType, boolean wantToRetire, boolean wantToDeprovision, - Optional<TenantName> reservedTo, DockerImage wantedDockerImage, DockerImage currentDockerImage) { + Optional<TenantName> reservedTo, DockerImage wantedDockerImage, DockerImage currentDockerImage, Map<String, JsonNode> reports) { this.hostname = hostname; this.parentHostname = parentHostname; this.state = state; @@ -82,6 +86,7 @@ public class Node { this.reservedTo = reservedTo; this.wantedDockerImage = wantedDockerImage; this.currentDockerImage = currentDockerImage; + this.reports = reports; } public HostName hostname() { @@ -188,6 +193,10 @@ public class Node { public Optional<TenantName> reservedTo() { return reservedTo; } + public Map<String, JsonNode> reports() { + return reports; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -258,6 +267,7 @@ public class Node { private boolean wantToRetire; private boolean wantToDeprovision; private Optional<TenantName> reservedTo = Optional.empty(); + private Map<String, JsonNode> reports = new HashMap<>(); public Builder() { } @@ -289,6 +299,7 @@ public class Node { this.wantToRetire = node.wantToRetire; this.wantToDeprovision = node.wantToDeprovision; this.reservedTo = node.reservedTo; + this.reports = node.reports; } public Builder hostname(HostName hostname) { @@ -431,7 +442,7 @@ public class Node { currentOsVersion, wantedOsVersion, currentFirmwareCheck, wantedFirmwareCheck, serviceState, suspendedSince, restartGeneration, wantedRestartGeneration, rebootGeneration, wantedRebootGeneration, cost, flavor, clusterId, clusterType, wantToRetire, wantToDeprovision, reservedTo, - wantedDockerImage, currentDockerImage); + wantedDockerImage, currentDockerImage, reports); } } diff --git 
a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/NodeRepository.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/NodeRepository.java index aebfab7cbff..6f4b39ac9b9 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/NodeRepository.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/NodeRepository.java @@ -90,6 +90,8 @@ public interface NodeRepository { void retireAndDeprovision(ZoneId zoneId, String hostName); + void patchNode(ZoneId zoneId, String hostName, NodeRepositoryNode node); + private static Node toNode(NodeRepositoryNode node) { var application = Optional.ofNullable(node.getOwner()) .map(owner -> ApplicationId.from(owner.getTenant(), owner.getApplication(), @@ -128,7 +130,8 @@ public interface NodeRepository { node.getWantToDeprovision(), Optional.ofNullable(node.getReservedTo()).map(TenantName::from), dockerImageFrom(node.getWantedDockerImage()), - dockerImageFrom(node.getCurrentDockerImage())); + dockerImageFrom(node.getCurrentDockerImage()), + node.getReports()); } private static String clusterIdOf(NodeMembership nodeMembership) { diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/HostRepairClient.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/HostRepairClient.java new file mode 100644 index 00000000000..a4a5a773cb9 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/HostRepairClient.java @@ -0,0 +1,23 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.vespa.hosted.controller.api.integration.repair; + +import com.yahoo.config.provision.HostName; +import com.yahoo.config.provision.zone.ZoneApi; +import com.yahoo.config.provision.zone.ZoneId; +import com.yahoo.vespa.hosted.controller.api.integration.configserver.Node; + +import java.util.List; +import java.util.Map; + +/** + * @author olaa + */ +public interface HostRepairClient { + + /* Checks current ticket status and takes appropriate action */ + void updateRepairStatus(ZoneApi zone, Map<Node, RepairTicketReport> nodes); + + /* Creates reparation ticket for given host. Returns ticket number */ + String createTicket(HostName hostname, String colo, ZoneId zoneId, String description, String category); + +} diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/MockRepairClient.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/MockRepairClient.java new file mode 100644 index 00000000000..6ceceda5712 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/MockRepairClient.java @@ -0,0 +1,33 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.vespa.hosted.controller.api.integration.repair; + +import com.yahoo.config.provision.HostName; +import com.yahoo.config.provision.zone.ZoneApi; +import com.yahoo.config.provision.zone.ZoneId; +import com.yahoo.vespa.hosted.controller.api.integration.configserver.Node; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * @author olaa + */ +public class MockRepairClient implements HostRepairClient { + + List<Node> updatedNodes = new ArrayList<>(); + + @Override + public void updateRepairStatus(ZoneApi zone, Map<Node, RepairTicketReport> nodes) { + updatedNodes.addAll(nodes.keySet()); + } + + @Override + public String createTicket(HostName hostname, String colo, ZoneId zoneId, String description, String category) { + throw new UnsupportedOperationException("Not implemented"); + } + + public List<Node> getUpdatedNodes() { + return updatedNodes; + } +} diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/RepairTicketReport.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/RepairTicketReport.java new file mode 100644 index 00000000000..c2425fe0f72 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/RepairTicketReport.java @@ -0,0 +1,63 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.vespa.hosted.controller.api.integration.repair; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import static com.yahoo.yolean.Exceptions.uncheck; + +/** + * @author olaa + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public class RepairTicketReport { + + private static final String REPORT_ID = "repairTicket"; + private static final ObjectMapper objectMapper = new ObjectMapper(); + + public String status; + public String ticketNumber; + public long createdMillis; + public long updatedMillis; + + public RepairTicketReport(@JsonProperty("status") String status, + @JsonProperty("ticketNumber") String ticketNumber, + @JsonProperty("createdMillis") long createdMillis, + @JsonProperty("updatedMillis") long updatedMillis) { + this.status = status; + this.ticketNumber = ticketNumber; + this.createdMillis = createdMillis; + this.updatedMillis = updatedMillis; + } + + public String getStatus() { + return status; + } + + public String getTicketNumber() { + return ticketNumber; + } + + public long getCreatedMillis() { + return createdMillis; + } + + public long getUpdatedMillis() { + return updatedMillis; + } + + public static String getReportId() { + return REPORT_ID; + } + + public static RepairTicketReport fromJsonNode(JsonNode node) { + return uncheck(() -> objectMapper.treeToValue(node, RepairTicketReport.class)); + } + + public JsonNode toJsonNode() { + return uncheck(() -> objectMapper.valueToTree(this)); + } +} diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/package-info.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/package-info.java new file mode 100644 index 00000000000..f53cb1ee43c --- /dev/null +++ 
b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/package-info.java @@ -0,0 +1,5 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +package com.yahoo.vespa.hosted.controller.api.integration.repair; + +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java index 5970494d471..a09dc0589ed 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java @@ -34,6 +34,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.aws.ApplicationRoles; import com.yahoo.vespa.hosted.controller.api.integration.billing.BillingController; import com.yahoo.vespa.hosted.controller.api.integration.billing.Quota; import com.yahoo.vespa.hosted.controller.api.integration.certificates.EndpointCertificateMetadata; +import com.yahoo.vespa.hosted.controller.api.integration.configserver.Cluster; import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServer; import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServerException; import com.yahoo.vespa.hosted.controller.api.integration.configserver.ContainerEndpoint; @@ -353,15 +354,27 @@ public class ApplicationController { // Carry out deployment without holding the application lock. 
ActivateResult result = deploy(job.application(), applicationPackage, zone, platform, endpoints, endpointCertificateMetadata, applicationRoles); + // Record the quota usage for this application + var quotaUsage = deploymentQuotaUsage(zone, job.application()); + lockApplicationOrThrow(applicationId, application -> store(application.with(job.application().instance(), instance -> instance.withNewDeployment(zone, revision, platform, clock.instant(), warningsFrom(result), - QuotaUsage.create(result.quotaUsageRate()))))); + quotaUsage)))); return result; } } + private QuotaUsage deploymentQuotaUsage(ZoneId zoneId, ApplicationId applicationId) { + var quotaUsage = configServer.nodeRepository().getApplication(zoneId, applicationId) + .clusters().values().stream() + .map(Cluster::max) + .mapToDouble(max -> max.nodes() * max.nodeResources().cost()) + .sum(); + return QuotaUsage.create(quotaUsage); + } + private ApplicationPackage getApplicationPackage(ApplicationId application, ZoneId zone, ApplicationVersion revision) { return new ApplicationPackage(revision.isUnknown() ? 
applicationStore.getDev(application, zone) : applicationStore.get(application.tenant(), application.application(), revision)); @@ -429,11 +442,14 @@ public class ApplicationController { ActivateResult result = deploy(instanceId, applicationPackage, zone, platformVersion, endpoints, endpointCertificateMetadata, Optional.empty()); + // Record the quota usage for this application + var quotaUsage = deploymentQuotaUsage(zone, instanceId); + lockApplicationOrThrow(applicationId, application -> store(application.with(instanceId.instance(), instance -> instance.withNewDeployment(zone, applicationVersion, platformVersion, clock.instant(), warningsFrom(result), - QuotaUsage.create(result.quotaUsageRate()))))); + quotaUsage)))); return result; } } @@ -547,10 +563,8 @@ public class ApplicationController { endpoints, endpointCertificateMetadata, dockerImageRepo, domain, applicationRoles, quota)); - var quotaUsage = configServer.getQuotaUsage(new DeploymentId(application, zone)); - return new ActivateResult(new RevisionId(applicationPackage.hash()), preparedApplication.prepareResponse(), - applicationPackage.zippedContent().length, quotaUsage.rate); + applicationPackage.zippedContent().length); } finally { // Even if prepare fails, a load balancer may have been provisioned. Always refresh routing policies so that // any DNS updates can be propagated as early as possible. 
@@ -567,7 +581,7 @@ public class ApplicationController { PrepareResponse prepareResponse = new PrepareResponse(); prepareResponse.log = List.of(logEntry); prepareResponse.configChangeActions = new ConfigChangeActions(List.of(), List.of()); - return new ActivateResult(new RevisionId("0"), prepareResponse, 0, 0.0); + return new ActivateResult(new RevisionId("0"), prepareResponse, 0); } private LockedApplication withoutDeletedDeployments(LockedApplication application, InstanceName instance) { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/api/ActivateResult.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/api/ActivateResult.java index e6c9e52ff69..5379a08afc0 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/api/ActivateResult.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/api/ActivateResult.java @@ -13,13 +13,11 @@ public class ActivateResult { private final RevisionId revisionId; private final PrepareResponse prepareResponse; private final long applicationZipSizeBytes; - private final double quotaUsageRate; - public ActivateResult(RevisionId revisionId, PrepareResponse prepareResponse, long applicationZipSizeBytes, double quotaUsageRate) { + public ActivateResult(RevisionId revisionId, PrepareResponse prepareResponse, long applicationZipSizeBytes) { this.revisionId = revisionId; this.prepareResponse = prepareResponse; this.applicationZipSizeBytes = applicationZipSizeBytes; - this.quotaUsageRate = quotaUsageRate; } public long applicationZipSizeBytes() { @@ -33,9 +31,4 @@ public class ActivateResult { public PrepareResponse prepareResponse() { return prepareResponse; } - - public double quotaUsageRate() { - return quotaUsageRate; - } - } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java index 
124b913eb01..b4904ca3cf8 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java @@ -646,12 +646,10 @@ public class DeploymentStatus { @Override public Optional<Instant> completedAt(Change change, Optional<JobId> dependent) { return RunList.from(job) - .matching(run -> change.platform().map(run.versions().targetPlatform()::equals).orElse(true)) - .matching(run -> change.application().map(run.versions().targetApplication()::equals).orElse(true)) - .matching(run -> dependent.flatMap(status::deploymentFor) - .map(deployment -> Versions.from(change, deployment)) - .map(run.versions()::targetsMatch) - .orElse(true)) + .matching(run -> run.versions().targetsMatch(Versions.from(change, + status.application, + dependent.flatMap(status::deploymentFor), + status.systemVersion))) .status(RunStatus.success) .asList().stream() .map(run -> run.end().get()) diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java index 0e72a1b42a7..6731c30ecd7 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java @@ -45,6 +45,8 @@ public class ControllerMaintenance extends AbstractComponent { private final ResourceTagMaintainer resourceTagMaintainer; private final SystemRoutingPolicyMaintainer systemRoutingPolicyMaintainer; private final ApplicationMetaDataGarbageCollector applicationMetaDataGarbageCollector; + private final HostRepairMaintainer hostRepairMaintainer; + @Inject @SuppressWarnings("unused") // instantiated by Dependency Injection @@ -75,6 +77,7 @@ public class ControllerMaintenance 
extends AbstractComponent { resourceTagMaintainer = new ResourceTagMaintainer(controller, Duration.ofMinutes(30), controller.serviceRegistry().resourceTagger()); systemRoutingPolicyMaintainer = new SystemRoutingPolicyMaintainer(controller, Duration.ofMinutes(10)); applicationMetaDataGarbageCollector = new ApplicationMetaDataGarbageCollector(controller, Duration.ofHours(12)); + hostRepairMaintainer = new HostRepairMaintainer(controller, Duration.ofHours(12)); } public Upgrader upgrader() { return upgrader; } @@ -102,6 +105,7 @@ public class ControllerMaintenance extends AbstractComponent { rotationStatusUpdater.close(); resourceTagMaintainer.close(); systemRoutingPolicyMaintainer.close(); + hostRepairMaintainer.close(); } /** Create one OS upgrader per cloud found in the zone registry of controller */ diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentExpirer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentExpirer.java index 7bd2c737fcb..37de7369452 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentExpirer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentExpirer.java @@ -25,10 +25,10 @@ public class DeploymentExpirer extends ControllerMaintainer { @Override protected boolean maintain() { boolean success = true; - for (Application application : controller().applications().readable()) + for (Application application : controller().applications().readable()) { for (Instance instance : application.instances().values()) for (Deployment deployment : instance.deployments().values()) { - if ( ! 
isExpired(deployment)) continue; + if (!isExpired(deployment)) continue; try { log.log(Level.INFO, "Expiring deployment of " + instance.id() + " in " + deployment.zone()); @@ -40,6 +40,7 @@ public class DeploymentExpirer extends ControllerMaintainer { interval()); } } + } return success; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/HostRepairMaintainer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/HostRepairMaintainer.java new file mode 100644 index 00000000000..e3c6862384f --- /dev/null +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/HostRepairMaintainer.java @@ -0,0 +1,81 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.maintenance; + +import com.yahoo.config.provision.CloudName; +import com.yahoo.config.provision.SystemName; +import com.yahoo.config.provision.zone.ZoneApi; +import com.yahoo.vespa.hosted.controller.Controller; +import com.yahoo.vespa.hosted.controller.api.integration.configserver.Node; +import com.yahoo.vespa.hosted.controller.api.integration.configserver.NodeRepository; +import com.yahoo.vespa.hosted.controller.api.integration.repair.RepairTicketReport; +import com.yahoo.vespa.hosted.controller.api.integration.repair.HostRepairClient; +import com.yahoo.yolean.Exceptions; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Predicate; +import java.util.logging.Logger; +import java.util.stream.Collectors; + +import static com.yahoo.yolean.Exceptions.uncheck; + +/** + * + * Responsible for keeping track of hosts under repair. 
+ * + * @author olaa + */ +public class HostRepairMaintainer extends ControllerMaintainer { + + private final NodeRepository nodeRepository; + private final HostRepairClient repairClient; + + private static final Logger log = Logger.getLogger(HostRepairMaintainer.class.getName()); + + + public HostRepairMaintainer(Controller controller, Duration interval) { + super(controller, interval, null, SystemName.allOf(Predicate.not(SystemName::isPublic))); + this.nodeRepository = controller.serviceRegistry().configServer().nodeRepository(); + this.repairClient = controller.serviceRegistry().hostRepairClient(); + } + + + @Override + protected boolean maintain() { + AtomicInteger exceptions = new AtomicInteger(0); + + controller().zoneRegistry().zones() + .reachable().zones().stream() + .forEach(zoneApi -> { + var nodeTicketMap = nodeRepository.list((zoneApi).getId()) + .stream() + .filter(this::hasOpenTicket) + .collect(Collectors.toMap( + node -> node, + this::getTicketReport) + ); + try { + repairClient.updateRepairStatus(zoneApi, nodeTicketMap); + } catch (Exception e) { + log.warning("Failed to update repair status; " + Exceptions.toMessageString(e)); + exceptions.incrementAndGet(); + } + } + ); + + return exceptions.get() == 0; + } + + + private boolean hasOpenTicket(Node node) { + var reports = node.reports(); + if (!reports.containsKey(RepairTicketReport.getReportId())) { + return false; + } + return "OPEN".equals(getTicketReport(node).getStatus()); + } + + private RepairTicketReport getTicketReport(Node node) { + return uncheck(() -> RepairTicketReport.fromJsonNode(node.reports().get(RepairTicketReport.getReportId()))); + } +} diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java index f1306b51b39..6e5a9ddc7ab 100644 --- 
a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java @@ -24,6 +24,7 @@ import java.time.Instant; import java.util.Collection; import java.util.EnumSet; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.OptionalLong; import java.util.stream.Collectors; @@ -1216,4 +1217,42 @@ public class DeploymentTriggerTest { app.assertNotRunning(stagingTest); } + @Test + public void testTriggeringOfIdleTestJobsWhenFirstDeploymentIsOnNewerVersionThanChange() { + ApplicationPackage applicationPackage = new ApplicationPackageBuilder().systemTest() + .stagingTest() + .region("us-east-3") + .region("us-west-1") + .build(); + var app = tester.newDeploymentContext().submit(applicationPackage).deploy(); + var appToAvoidVersionGC = tester.newDeploymentContext("g", "c", "default").submit().deploy(); + + Version version2 = new Version("7.8.9"); + Version version3 = new Version("8.9.10"); + tester.controllerTester().upgradeSystem(version2); + tester.deploymentTrigger().triggerChange(appToAvoidVersionGC.instanceId(), Change.of(version2)); + appToAvoidVersionGC.deployPlatform(version2); + + // app upgrades first zone to version3, and then the other two to version2. 
+ tester.controllerTester().upgradeSystem(version3); + tester.deploymentTrigger().triggerChange(app.instanceId(), Change.of(version3)); + app.runJob(systemTest).runJob(stagingTest); + tester.triggerJobs(); + tester.upgrader().overrideConfidence(version3, VespaVersion.Confidence.broken); + tester.controllerTester().computeVersionStatus(); + tester.upgrader().run(); + assertEquals(Optional.of(version2), app.instance().change().platform()); + + app.runJob(systemTest) + .runJob(productionUsEast3) + .runJob(stagingTest) + .runJob(productionUsWest1); + + assertEquals(version3, app.instanceJobs().get(productionUsEast3).lastSuccess().get().versions().targetPlatform()); + assertEquals(version2, app.instanceJobs().get(productionUsWest1).lastSuccess().get().versions().targetPlatform()); + assertEquals(Map.of(), app.deploymentStatus().jobsToRun()); + assertEquals(Change.empty(), app.instance().change()); + assertEquals(List.of(), tester.jobs().active()); + } + } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/QuotaUsageTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/QuotaUsageTest.java new file mode 100644 index 00000000000..d2901aeac97 --- /dev/null +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/QuotaUsageTest.java @@ -0,0 +1,21 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.vespa.hosted.controller.deployment;

import com.yahoo.config.provision.zone.ZoneId;
import org.junit.Test;

import static org.junit.Assert.assertEquals;

/**
 * Verifies that quota usage computed at deployment time is stored on the deployment.
 *
 * @author ogronnesby
 */
public class QuotaUsageTest {

    @Test
    public void testQuotaUsageIsPersisted() {
        var tester = new DeploymentTester();
        var context = tester.newDeploymentContext().submit().deploy();
        // 1.062 is the usage rate produced by the default test deployment's
        // cluster resources (max nodes * node resource cost) — TODO confirm against ConfigServerMock defaults
        assertEquals(1.062, context.deployment(ZoneId.from("prod.us-west-1")).quota().rate(), 0.01);
    }

}
package com.yahoo.vespa.hosted.controller.integration; +import com.fasterxml.jackson.databind.JsonNode; import com.yahoo.component.Version; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.HostName; @@ -226,6 +227,11 @@ public class NodeRepositoryMock implements NodeRepository { nodeRepository.get(zoneId).remove(HostName.from(hostName)); } + @Override + public void patchNode(ZoneId zoneId, String hostName, NodeRepositoryNode node) { + throw new UnsupportedOperationException(); + } + public Optional<Duration> osUpgradeBudget(ZoneId zone, NodeType type, Version version) { return Optional.ofNullable(osUpgradeBudgets.get(Objects.hash(zone, type, version))); } @@ -264,4 +270,8 @@ public class NodeRepositoryMock implements NodeRepository { modifyNodes(deployment, hostname, node -> new Node.Builder(node).rebootGeneration(node.rebootGeneration() + 1).build()); } + public void addReport(ZoneId zoneId, HostName hostName, String reportId, JsonNode report) { + nodeRepository.get(zoneId).get(hostName).reports().put(reportId, report); + } + } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java index 1b21f7db7c4..3ec02c6ceb7 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java @@ -20,6 +20,8 @@ import com.yahoo.vespa.hosted.controller.api.integration.dns.MemoryNameService; import com.yahoo.vespa.hosted.controller.api.integration.entity.MemoryEntityService; import com.yahoo.vespa.hosted.controller.api.integration.organization.MockContactRetriever; import com.yahoo.vespa.hosted.controller.api.integration.organization.MockIssueHandler; +import 
com.yahoo.vespa.hosted.controller.api.integration.repair.MockRepairClient; +import com.yahoo.vespa.hosted.controller.api.integration.repair.HostRepairClient; import com.yahoo.vespa.hosted.controller.api.integration.resource.CostReportConsumerMock; import com.yahoo.vespa.hosted.controller.api.integration.routing.GlobalRoutingService; import com.yahoo.vespa.hosted.controller.api.integration.routing.MemoryGlobalRoutingService; @@ -61,6 +63,7 @@ public class ServiceRegistryMock extends AbstractComponent implements ServiceReg private final MockResourceTagger mockResourceTagger = new MockResourceTagger(); private final ApplicationRoleService applicationRoleService = new NoopApplicationRoleService(); private final BillingController billingController = new MockBillingController(); + private final MockRepairClient repairClient = new MockRepairClient(); public ServiceRegistryMock(SystemName system) { this.zoneRegistryMock = new ZoneRegistryMock(system); @@ -192,6 +195,11 @@ public class ServiceRegistryMock extends AbstractComponent implements ServiceReg return billingController; } + @Override + public MockRepairClient hostRepairClient() { + return repairClient; + } + public ConfigServerMock configServerMock() { return configServerMock; } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/HostRepairMaintainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/HostRepairMaintainerTest.java new file mode 100644 index 00000000000..556755581fe --- /dev/null +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/HostRepairMaintainerTest.java @@ -0,0 +1,51 @@ +package com.yahoo.vespa.hosted.controller.maintenance; + +import com.yahoo.config.provision.HostName; +import com.yahoo.config.provision.zone.ZoneId; +import com.yahoo.vespa.hosted.controller.ControllerTester; +import com.yahoo.vespa.hosted.controller.api.integration.noderepository.NodeRepositoryNode; +import 
com.yahoo.vespa.hosted.controller.api.integration.repair.HostRepairClient;
import com.yahoo.vespa.hosted.controller.api.integration.repair.MockRepairClient;
import com.yahoo.vespa.hosted.controller.api.integration.repair.RepairTicketReport;
import org.junit.Test;

import java.time.Duration;
import java.time.Instant;

import static org.junit.Assert.assertEquals;

/**
 * Verifies that {@link HostRepairMaintainer} forwards only nodes with OPEN
 * repair tickets to the repair client.
 *
 * @author olaa
 */
public class HostRepairMaintainerTest {

    private final ControllerTester tester = new ControllerTester();
    private final HostRepairMaintainer maintainer = new HostRepairMaintainer(tester.controller(), Duration.ofHours(12));

    @Test
    public void maintain() {
        var zoneId = ZoneId.from("dev.us-east-1");
        var hostname1 = HostName.from("node-1-tenant-host-dev.us-east-1");
        var hostname2 = HostName.from("node-2-tenant-host-dev.us-east-1");
        var timestamp = Instant.now().toEpochMilli();
        var openTicket = new RepairTicketReport("OPEN", "ticket-1", timestamp, timestamp);
        var closedTicket = new RepairTicketReport("CLOSED", "ticket-2", timestamp, timestamp);

        tester.configServer().nodeRepository().addReport(
                zoneId,
                hostname1,
                RepairTicketReport.getReportId(),
                openTicket.toJsonNode());
        tester.configServer().nodeRepository().addReport(
                zoneId,
                hostname2,
                RepairTicketReport.getReportId(),
                closedTicket.toJsonNode());

        maintainer.maintain();

        // Only the node with the OPEN ticket should have been forwarded to the client
        var updatedNodes = tester.serviceRegistry().hostRepairClient().getUpdatedNodes();
        assertEquals(1, updatedNodes.size());
        assertEquals(hostname1, updatedNodes.get(0).hostname());
    }

}
\ No newline at end of file diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json index 385f0fbc3cf..bb3578b2482 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json @@ -28,6 +28,9 @@ "name": "DeploymentMetricsMaintainer" }, { + "name": "HostRepairMaintainer" + }, + { "name": "JobRunner" }, { diff --git a/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp b/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp index 0299dc3ebba..7182d66f8aa 100644 --- a/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp +++ b/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp @@ -36,7 +36,6 @@ #include <vespa/eval/eval/tensor_function.h> #include <vespa/eval/tensor/default_tensor_engine.h> #include <vespa/eval/tensor/default_value_builder_factory.h> -#include <vespa/eval/tensor/mixed/packed_mixed_tensor_builder_factory.h> #include <vespa/vespalib/util/benchmark_timer.h> #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/objects/nbostream.h> @@ -230,7 +229,6 @@ Impl default_tensor_engine_impl(1, "DefaultTensorEngine", "OLD PROD", DefaultTe Impl simple_value_impl(3, " SimpleValue", " SimpleV", SimpleValueBuilderFactory::get(), false); Impl fast_value_impl(0, " FastValue", "NEW PROD", FastValueBuilderFactory::get(), false); Impl optimized_fast_value_impl(2, "Optimized FastValue", "Optimize", FastValueBuilderFactory::get(), true); -Impl packed_mixed_tensor_impl(5, " PackedMixedTensor", " Packed", PackedMixedTensorBuilderFactory::get(), false); Impl default_tensor_value_impl(4, " DefaultValue", "DefaultV", 
DefaultValueBuilderFactory::get(), false); vespalib::string short_header("--------"); @@ -243,7 +241,6 @@ std::vector<CREF<Impl>> impl_list = {default_tensor_engine_impl, simple_value_impl, fast_value_impl, optimized_fast_value_impl, - packed_mixed_tensor_impl, default_tensor_value_impl}; //----------------------------------------------------------------------------- @@ -982,6 +979,14 @@ void print_summary() { } int main(int argc, char **argv) { + const std::string run_only_prod_option = "--limit-implementations"; + if ((argc > 1) && (argv[1] == run_only_prod_option )) { + impl_list.clear(); + impl_list.push_back(fast_value_impl); + impl_list.push_back(default_tensor_engine_impl); + ++argv; + --argc; + } ::testing::InitGoogleTest(&argc, argv); int result = RUN_ALL_TESTS(); print_summary(); diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml index 4beaf6086a6..181ef6dffbd 100644 --- a/fat-model-dependencies/pom.xml +++ b/fat-model-dependencies/pom.xml @@ -221,5 +221,10 @@ <artifactId>jdisc_http_service</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + <version>${protobuf.version}</version> + </dependency> </dependencies> </project> diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index c56de1bb178..dd6d84e3ad7 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -161,7 +161,7 @@ public class Flags { ZONE_ID, APPLICATION_ID); public static final UnboundBooleanFlag USE_CONTENT_NODE_BTREE_DB = defineFeatureFlag( - "use-content-node-btree-db", false, + "use-content-node-btree-db", true, "Whether to use the new B-tree bucket database on the content node.", "Takes effect at restart of content node process", ZONE_ID, APPLICATION_ID); diff --git 
a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java index c73a19bd9e2..eace7457615 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java @@ -61,7 +61,7 @@ public class Autoscaler { private Optional<AllocatableClusterResources> autoscale(Cluster cluster, List<Node> clusterNodes, Limits limits, boolean exclusive) { - if (unstable(clusterNodes)) return Optional.empty(); + if (unstable(clusterNodes, nodeRepository)) return Optional.empty(); AllocatableClusterResources currentAllocation = new AllocatableClusterResources(clusterNodes, nodeRepository); @@ -111,10 +111,18 @@ public class Autoscaler { return 20; } - public static boolean unstable(List<Node> nodes) { - return nodes.stream().anyMatch(node -> node.status().wantToRetire() || - node.allocation().get().membership().retired() || - node.allocation().get().isRemovable()); + public static boolean unstable(List<Node> nodes, NodeRepository nodeRepository) { + // The cluster is processing recent changes + if (nodes.stream().anyMatch(node -> node.status().wantToRetire() || + node.allocation().get().membership().retired() || + node.allocation().get().isRemovable())) + return true; + + // A deployment is ongoing + if (nodeRepository.getNodes(nodes.get(0).allocation().get().owner(), Node.State.reserved).size() > 0) + return true; + + return false; } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsV2MetricsFetcher.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsV2MetricsFetcher.java index b4a63175548..4597fc04e17 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsV2MetricsFetcher.java +++ 
b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsV2MetricsFetcher.java @@ -56,7 +56,7 @@ public class MetricsV2MetricsFetcher extends AbstractComponent implements Metric NodeList applicationNodes = nodeRepository.list(application).state(Node.State.active); // Do not try to draw conclusions from utilization while unstable - if (Autoscaler.unstable(applicationNodes.asList())) return Collections.emptyList(); + if (Autoscaler.unstable(applicationNodes.asList(), nodeRepository)) return Collections.emptyList(); Optional<Node> metricsV2Container = applicationNodes.container() .matching(node -> expectedUp(node)) diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java index 3b01f678982..c0fd7df9b2e 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java @@ -28,7 +28,6 @@ import java.util.stream.Collectors; */ public class AutoscalingMaintainer extends NodeRepositoryMaintainer { - private final MetricsDb metricsDb; private final Autoscaler autoscaler; private final Deployer deployer; private final Metric metric; @@ -40,7 +39,6 @@ public class AutoscalingMaintainer extends NodeRepositoryMaintainer { Duration interval) { super(nodeRepository, interval, metric); this.autoscaler = new Autoscaler(metricsDb, nodeRepository); - this.metricsDb = metricsDb; this.metric = metric; this.deployer = deployer; } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainer.java index c9538d878f2..9ef5a841a7a 100644 --- 
a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainer.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.hosted.provision.maintenance; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ApplicationLockException; import com.yahoo.config.provision.ClusterResources; import com.yahoo.config.provision.ClusterSpec; import com.yahoo.jdisc.Metric; @@ -39,32 +40,39 @@ public class ScalingSuggestionsMaintainer extends NodeRepositoryMaintainer { @Override protected boolean maintain() { - boolean success = true; - if ( ! nodeRepository().zone().environment().isProduction()) return success; + if ( ! nodeRepository().zone().environment().isProduction()) return true; - activeNodesByApplication().forEach((applicationId, nodes) -> suggest(applicationId, nodes)); - return success; + int successes = 0; + for (var application : activeNodesByApplication().entrySet()) + successes += suggest(application.getKey(), application.getValue()); + return successes > 0; } - private void suggest(ApplicationId application, List<Node> applicationNodes) { - nodesByCluster(applicationNodes).forEach((clusterId, clusterNodes) -> - suggest(application, clusterId, clusterNodes)); + private int suggest(ApplicationId application, List<Node> applicationNodes) { + int successes = 0; + for (var cluster : nodesByCluster(applicationNodes).entrySet()) + successes += suggest(application, cluster.getKey(), cluster.getValue()) ? 
1 : 0; + return successes; } private Applications applications() { return nodeRepository().applications(); } - private void suggest(ApplicationId applicationId, - ClusterSpec.Id clusterId, - List<Node> clusterNodes) { + private boolean suggest(ApplicationId applicationId, + ClusterSpec.Id clusterId, + List<Node> clusterNodes) { Application application = applications().get(applicationId).orElse(new Application(applicationId)); Optional<Cluster> cluster = application.cluster(clusterId); - if (cluster.isEmpty()) return; + if (cluster.isEmpty()) return true; Optional<ClusterResources> suggestion = autoscaler.suggest(cluster.get(), clusterNodes); // Wait only a short time for the lock to avoid interfering with change deployments try (Mutex lock = nodeRepository().lock(applicationId, Duration.ofSeconds(1))) { applications().get(applicationId).ifPresent(a -> storeSuggestion(suggestion, clusterId, a, lock)); + return true; + } + catch (ApplicationLockException e) { + return false; } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java index 240963a8c0d..1e98160955c 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java @@ -112,8 +112,11 @@ class NodeAllocation { boolean resizeable = requestedNodes.considerRetiring() && candidate.isResizable; boolean acceptToRetire = acceptToRetire(candidate); - if ((! saturated() && hasCompatibleFlavor(candidate) && requestedNodes.acceptable(candidate)) || acceptToRetire) - accepted.add(acceptNode(candidate, shouldRetire(candidate), resizeable)); + if ((! 
saturated() && hasCompatibleFlavor(candidate) && requestedNodes.acceptable(candidate)) || acceptToRetire) { + candidate = candidate.withNode(); + if (candidate.isValid()) + accepted.add(acceptNode(candidate, shouldRetire(candidate), resizeable)); + } } else if (! saturated() && hasCompatibleFlavor(candidate)) { if ( ! nodeResourceLimits.isWithinRealLimits(candidate, cluster)) { @@ -240,7 +243,6 @@ class NodeAllocation { } private Node acceptNode(NodeCandidate candidate, boolean wantToRetire, boolean resizeable) { - candidate = candidate.withNode(); Node node = candidate.toNode(); if (node.allocation().isPresent()) // Record the currently requested resources @@ -356,7 +358,7 @@ class NodeAllocation { candidate = candidate.withNode(); Allocation allocation = candidate.allocation().get(); candidate = candidate.withNode(candidate.toNode().with(allocation.with(allocation.membership() - .with(allocation.membership().cluster().exclusive(requestedNodes.isExclusive()))))); + .with(allocation.membership().cluster().exclusive(requestedNodes.isExclusive()))))); nodes.put(candidate.toNode().hostname(), candidate); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeCandidate.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeCandidate.java index b915053fff5..02086e2bace 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeCandidate.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeCandidate.java @@ -87,7 +87,11 @@ abstract class NodeCandidate implements Nodelike, Comparable<NodeCandidate> { /** Returns a copy of this with exclusive switch set to given value */ public abstract NodeCandidate withExclusiveSwitch(boolean exclusiveSwitch); - /** Returns the node instance of this candidate, or an invalid node if it cannot be created */ + /** + * Returns the node instance of this candidate, allocating it if necessary. 
+ * + * @throws IllegalStateException if the node candidate is invalid + */ public abstract Node toNode(); /** Returns whether this node can - as far as we know - be used to run the application workload */ @@ -358,10 +362,12 @@ abstract class NodeCandidate implements Nodelike, Comparable<NodeCandidate> { Optional<IP.Allocation> allocation; try { allocation = parent.get().ipConfig().pool().findAllocation(allNodes, nodeRepository.nameResolver()); - if (allocation.isEmpty()) return new InvalidNodeCandidate(resources, freeParentCapacity, parent.get()); + if (allocation.isEmpty()) return new InvalidNodeCandidate(resources, freeParentCapacity, parent.get(), + "No IP addresses available on parent host"); } catch (Exception e) { log.warning("Failed allocating IP address on " + parent.get() +": " + Exceptions.toMessageString(e)); - return new InvalidNodeCandidate(resources, freeParentCapacity, parent.get()); + return new InvalidNodeCandidate(resources, freeParentCapacity, parent.get(), + "Failed when allocating IP address on host"); } Node node = Node.createDockerNode(allocation.get().addresses(), @@ -409,10 +415,13 @@ abstract class NodeCandidate implements Nodelike, Comparable<NodeCandidate> { static class InvalidNodeCandidate extends NodeCandidate { private final NodeResources resources; + private final String invalidReason; - private InvalidNodeCandidate(NodeResources resources, NodeResources freeParentCapacity, Node parent) { + private InvalidNodeCandidate(NodeResources resources, NodeResources freeParentCapacity, Node parent, + String invalidReason) { super(freeParentCapacity, Optional.of(parent), false, false, false, true, false); this.resources = resources; + this.invalidReason = invalidReason; } @Override @@ -453,7 +462,7 @@ abstract class NodeCandidate implements Nodelike, Comparable<NodeCandidate> { @Override public Node toNode() { - throw new IllegalStateException("Candidate node on " + parent.get() + " is invalid"); + throw new IllegalStateException("Candidate 
node on " + parent.get() + " is invalid: " + invalidReason); } @Override diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodePrioritizer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodePrioritizer.java index 3aa6253979d..de7adf9fa2d 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodePrioritizer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodePrioritizer.java @@ -87,19 +87,14 @@ public class NodePrioritizer { /** Returns the list of nodes sorted by {@link NodeCandidate#compareTo(NodeCandidate)} */ private List<NodeCandidate> prioritize() { - // Group candidates by their cluster switch - Map<ClusterSwitch, List<NodeCandidate>> candidatesBySwitch = this.nodes.stream().collect(Collectors.groupingBy(candidate -> { - Node nodeOnSwitch = candidate.parent.orElseGet(candidate::toNode); - ClusterSpec.Id cluster = candidate.toNode().allocation() - .map(a -> a.membership().cluster().id()) - .orElseGet(clusterSpec::id); - return ClusterSwitch.from(cluster, nodeOnSwitch.switchHostname()); - })); + // Group candidates by their switch hostname + Map<Optional<String>, List<NodeCandidate>> candidatesBySwitch = this.nodes.stream() + .collect(Collectors.groupingBy(candidate -> candidate.parent.orElseGet(candidate::toNode).switchHostname())); // Mark lower priority nodes on shared switch as non-exclusive List<NodeCandidate> nodes = new ArrayList<>(this.nodes.size()); for (var clusterSwitch : candidatesBySwitch.keySet()) { List<NodeCandidate> switchCandidates = candidatesBySwitch.get(clusterSwitch); - if (clusterSwitch.equals(ClusterSwitch.unknown)) { + if (clusterSwitch.isEmpty()) { nodes.addAll(switchCandidates); // Nodes are on exclusive switch by default } else { Collections.sort(switchCandidates); @@ -156,6 +151,7 @@ public class NodePrioritizer { .filter(node -> legalStates.contains(node.state())) .filter(node -> 
node.allocation().isPresent()) .filter(node -> node.allocation().get().owner().equals(application)) + .filter(node -> node.allocation().get().membership().cluster().id().equals(clusterSpec.id())) .filter(node -> node.state() == Node.State.active || canStillAllocateToParentOf(node)) .map(node -> candidateFrom(node, false)) .forEach(nodes::add); @@ -206,43 +202,9 @@ public class NodePrioritizer { */ private boolean canStillAllocateToParentOf(Node node) { if (node.parentHostname().isEmpty()) return true; - Optional<Node> parent = node.parentHostname().flatMap(nodeRepository::getNode); + Optional<Node> parent = allNodes.parentOf(node); if (parent.isEmpty()) return false; return nodeRepository.canAllocateTenantNodeTo(parent.get()); } - /** A cluster and its network switch */ - private static class ClusterSwitch { - - private static final ClusterSwitch unknown = new ClusterSwitch(null, null); - - private final ClusterSpec.Id cluster; - private final String switchHostname; - - public ClusterSwitch(ClusterSpec.Id cluster, String switchHostname) { - this.cluster = cluster; - this.switchHostname = switchHostname; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - ClusterSwitch that = (ClusterSwitch) o; - return Objects.equals(cluster, that.cluster) && - Objects.equals(switchHostname, that.switchHostname); - } - - @Override - public int hashCode() { - return Objects.hash(cluster, switchHostname); - } - - public static ClusterSwitch from(ClusterSpec.Id cluster, Optional<String> switchHostname) { - if (switchHostname.isEmpty()) return unknown; - return new ClusterSwitch(cluster, switchHostname.get()); - } - - } - } diff --git a/searchcore/src/tests/proton/attribute/attribute_populator/attribute_populator_test.cpp b/searchcore/src/tests/proton/attribute/attribute_populator/attribute_populator_test.cpp index 683a6cd1197..2a5444c2525 100644 --- 
a/searchcore/src/tests/proton/attribute/attribute_populator/attribute_populator_test.cpp +++ b/searchcore/src/tests/proton/attribute/attribute_populator/attribute_populator_test.cpp @@ -97,12 +97,12 @@ TEST_F("require that reprocess with document populates attribute", Fixture) f._pop->handleExisting(5, f._ctx.create(0, 33)); EXPECT_EQUAL(6u, attr->get()->getNumDocs()); EXPECT_EQUAL(33, attr->get()->getInt(5)); - EXPECT_EQUAL(1u, attr->get()->getStatus().getLastSyncToken()); + EXPECT_EQUAL(0u, attr->get()->getStatus().getLastSyncToken()); f._pop->handleExisting(6, f._ctx.create(1, 44)); EXPECT_EQUAL(7u, attr->get()->getNumDocs()); EXPECT_EQUAL(44, attr->get()->getInt(6)); - EXPECT_EQUAL(2u, attr->get()->getStatus().getLastSyncToken()); + EXPECT_EQUAL(0u, attr->get()->getStatus().getLastSyncToken()); f._pop->done(); EXPECT_EQUAL(CREATE_SERIAL_NUM, attr->get()->getStatus().getLastSyncToken()); } diff --git a/searchcore/src/tests/proton/attribute/attribute_test.cpp b/searchcore/src/tests/proton/attribute/attribute_test.cpp index b30420ead24..c98127f4daf 100644 --- a/searchcore/src/tests/proton/attribute/attribute_test.cpp +++ b/searchcore/src/tests/proton/attribute/attribute_test.cpp @@ -169,29 +169,34 @@ public: _mgr->addAttribute(attr->getName(), std::move(attr)); allocAttributeWriter(); } - void put(SerialNum serialNum, const Document &doc, DocumentIdT lid, - bool immediateCommit = true) { - _aw->put(serialNum, doc, lid, immediateCommit, emptyCallback); + void put(SerialNum serialNum, const Document &doc, DocumentIdT lid) { + _aw->put(serialNum, doc, lid, emptyCallback); + commit(serialNum); } void update(SerialNum serialNum, const DocumentUpdate &upd, - DocumentIdT lid, bool immediateCommit, IFieldUpdateCallback & onUpdate) { - _aw->update(serialNum, upd, lid, immediateCommit, emptyCallback, onUpdate); + DocumentIdT lid, IFieldUpdateCallback & onUpdate) { + _aw->update(serialNum, upd, lid, emptyCallback, onUpdate); + commit(serialNum); } - void update(SerialNum 
serialNum, const Document &doc, - DocumentIdT lid, bool immediateCommit) { - _aw->update(serialNum, doc, lid, immediateCommit, emptyCallback); + void update(SerialNum serialNum, const Document &doc, DocumentIdT lid) { + _aw->update(serialNum, doc, lid, emptyCallback); + commit(serialNum); } - void remove(SerialNum serialNum, DocumentIdT lid, bool immediateCommit = true) { - _aw->remove(serialNum, lid, immediateCommit, emptyCallback); + void remove(SerialNum serialNum, DocumentIdT lid) { + _aw->remove(serialNum, lid, emptyCallback); + commit(serialNum); } - void remove(const LidVector &lidVector, SerialNum serialNum, bool immediateCommit = true) { - _aw->remove(lidVector, serialNum, immediateCommit, emptyCallback); + void remove(const LidVector &lidVector, SerialNum serialNum) { + _aw->remove(lidVector, serialNum, emptyCallback); + commit(serialNum); } void commit(SerialNum serialNum) { _aw->forceCommit(serialNum, emptyCallback); } void assertExecuteHistory(std::vector<uint32_t> expExecuteHistory) { - EXPECT_EQ(expExecuteHistory, _attributeFieldWriter->getExecuteHistory()); + auto includeCommit = expExecuteHistory; + includeCommit.insert(includeCommit.end(), expExecuteHistory.begin(), expExecuteHistory.end()); + EXPECT_EQ(includeCommit, _attributeFieldWriter->getExecuteHistory()); } SerialNum test_force_commit(AttributeVector &attr, SerialNum serialNum) { commit(serialNum); @@ -400,29 +405,29 @@ TEST_F(AttributeWriterTest, visibility_delay_is_honoured) EXPECT_EQ(2u, a1->getNumDocs()); EXPECT_EQ(3u, a1->getStatus().getLastSyncToken()); AttributeWriter awDelayed(_mgr); - awDelayed.put(4, *doc, 2, false, emptyCallback); + awDelayed.put(4, *doc, 2, emptyCallback); EXPECT_EQ(3u, a1->getNumDocs()); EXPECT_EQ(3u, a1->getStatus().getLastSyncToken()); - awDelayed.put(5, *doc, 4, false, emptyCallback); + awDelayed.put(5, *doc, 4, emptyCallback); EXPECT_EQ(5u, a1->getNumDocs()); EXPECT_EQ(3u, a1->getStatus().getLastSyncToken()); awDelayed.forceCommit(6, emptyCallback); 
EXPECT_EQ(6u, a1->getStatus().getLastSyncToken()); AttributeWriter awDelayedShort(_mgr); - awDelayedShort.put(7, *doc, 2, false, emptyCallback); + awDelayedShort.put(7, *doc, 2, emptyCallback); EXPECT_EQ(6u, a1->getStatus().getLastSyncToken()); - awDelayedShort.put(8, *doc, 2, false, emptyCallback); + awDelayedShort.put(8, *doc, 2, emptyCallback); awDelayedShort.forceCommit(8, emptyCallback); EXPECT_EQ(8u, a1->getStatus().getLastSyncToken()); verifyAttributeContent(*a1, 2, "10"); awDelayed.put(9, *idb.startDocument("id:ns:searchdocument::1").startAttributeField("a1").addStr("11").endField().endDocument(), - 2, false, emptyCallback); + 2, emptyCallback); awDelayed.put(10, *idb.startDocument("id:ns:searchdocument::1").startAttributeField("a1").addStr("20").endField().endDocument(), - 2, false, emptyCallback); + 2, emptyCallback); awDelayed.put(11, *idb.startDocument("id:ns:searchdocument::1").startAttributeField("a1").addStr("30").endField().endDocument(), - 2, false, emptyCallback); + 2, emptyCallback); EXPECT_EQ(8u, a1->getStatus().getLastSyncToken()); verifyAttributeContent(*a1, 2, "10"); awDelayed.forceCommit(12, emptyCallback); @@ -472,8 +477,7 @@ TEST_F(AttributeWriterTest, handles_update) .addUpdate(ArithmeticValueUpdate(ArithmeticValueUpdate::Add, 10))); DummyFieldUpdateCallback onUpdate; - bool immediateCommit = true; - update(2, upd, 1, immediateCommit, onUpdate); + update(2, upd, 1, onUpdate); attribute::IntegerContent ibuf; ibuf.fill(*a1, 1); @@ -483,9 +487,9 @@ TEST_F(AttributeWriterTest, handles_update) EXPECT_EQ(1u, ibuf.size()); EXPECT_EQ(30u, ibuf[0]); - update(2, upd, 1, immediateCommit, onUpdate); // same sync token as previous + update(2, upd, 1, onUpdate); // same sync token as previous try { - update(1, upd, 1, immediateCommit, onUpdate); // lower sync token than previous + update(1, upd, 1, onUpdate); // lower sync token than previous EXPECT_TRUE(true); // update is ignored } catch (vespalib::IllegalStateException & e) { LOG(info, "Got expected 
exception: '%s'", e.getMessage().c_str()); @@ -517,9 +521,8 @@ TEST_F(AttributeWriterTest, handles_predicate_update) PredicateIndex &index = static_cast<PredicateAttribute &>(*a1).getIndex(); EXPECT_EQ(1u, index.getZeroConstraintDocs().size()); EXPECT_FALSE(index.getIntervalIndex().lookup(PredicateHash::hash64("foo=bar")).valid()); - bool immediateCommit = true; DummyFieldUpdateCallback onUpdate; - update(2, upd, 1, immediateCommit, onUpdate); + update(2, upd, 1, onUpdate); EXPECT_EQ(0u, index.getZeroConstraintDocs().size()); EXPECT_TRUE(index.getIntervalIndex().lookup(PredicateHash::hash64("foo=bar")).valid()); } @@ -712,9 +715,8 @@ TEST_F(AttributeWriterTest, handles_tensor_assign_update) new_value = EngineOrFactory::get().copy(*new_tensor); upd.addUpdate(FieldUpdate(upd.getType().getField("a1")) .addUpdate(AssignValueUpdate(new_value))); - bool immediateCommit = true; DummyFieldUpdateCallback onUpdate; - update(2, upd, 1, immediateCommit, onUpdate); + update(2, upd, 1, onUpdate); EXPECT_EQ(2u, a1->getNumDocs()); EXPECT_TRUE(tensorAttribute != nullptr); tensor2 = tensorAttribute->getTensor(1); @@ -1078,7 +1080,7 @@ TEST_F(StructArrayWriterTest, update_with_doc_argument_updates_struct_field_attr put(10, *doc, 1); checkAttrs(1, 10, {11, 12}); doc = makeDoc(20, {21}); - update(11, *doc, 1, true); + update(11, *doc, 1); checkAttrs(1, 10, {21}); } @@ -1135,7 +1137,7 @@ TEST_F(StructMapWriterTest, update_with_doc_argument_updates_struct_field_attrib put(10, *doc, 1); checkAttrs(1, 10, {{1, 11}, {2, 12}}); doc = makeDoc(20, {{42, 21}}); - update(11, *doc, 1, true); + update(11, *doc, 1); checkAttrs(1, 10, {{42, 21}}); } diff --git a/searchcore/src/tests/proton/docsummary/docsummary.cpp b/searchcore/src/tests/proton/docsummary/docsummary.cpp index 266a817d380..1d8b85864f6 100644 --- a/searchcore/src/tests/proton/docsummary/docsummary.cpp +++ b/searchcore/src/tests/proton/docsummary/docsummary.cpp @@ -97,7 +97,7 @@ public: BuildContext(const Schema &schema) : 
_dmk("summary"), _bld(schema), - _repo(new DocumentTypeRepo(_bld.getDocumentType())), + _repo(std::make_shared<DocumentTypeRepo>(_bld.getDocumentType())), _summaryExecutor(4, 128 * 1024), _noTlSyncer(), _str(_summaryExecutor, "summary", @@ -125,7 +125,7 @@ public: } FieldCacheRepo::UP createFieldCacheRepo(const ResultConfig &resConfig) const { - return FieldCacheRepo::UP(new FieldCacheRepo(resConfig, _bld.getDocumentType())); + return std::make_unique<FieldCacheRepo>(resConfig, _bld.getDocumentType()); } }; @@ -150,8 +150,7 @@ vespalib::string asVstring(const Inspector &value) { } void decode(const ResEntry *entry, vespalib::Slime &slime) { - vespalib::Memory mem(entry->_dataval, - entry->_datalen); + vespalib::Memory mem(entry->_dataval, entry->_datalen); size_t decodeRes = BinaryFormat::decode(mem, slime); ASSERT_EQUAL(decodeRes, mem.size); } @@ -216,14 +215,14 @@ public: if (! FastOS_File::MakeDirectory((std::string("tmpdb/") + docTypeName).c_str())) { LOG_ABORT("should not be reached"); } - _ddb.reset(new DocumentDB("tmpdb", _configMgr.getConfig(), "tcp/localhost:9013", _queryLimiter, _clock, - DocTypeName(docTypeName), makeBucketSpace(), - *b->getProtonConfigSP(), *this, _summaryExecutor, _summaryExecutor, - _tls, _dummy, _fileHeaderContext, ConfigStore::UP(new MemoryConfigStore), - std::make_shared<vespalib::ThreadStackExecutor>(16, 128 * 1024), _hwInfo)), + _ddb = std::make_unique<DocumentDB>("tmpdb", _configMgr.getConfig(), "tcp/localhost:9013", _queryLimiter, _clock, + DocTypeName(docTypeName), makeBucketSpace(), *b->getProtonConfigSP(), *this, + _summaryExecutor, _summaryExecutor, _tls, _dummy, _fileHeaderContext, + std::make_unique<MemoryConfigStore>(), + std::make_shared<vespalib::ThreadStackExecutor>(16, 128 * 1024), _hwInfo), _ddb->start(); _ddb->waitForOnlineState(); - _aw = AttributeWriter::UP(new AttributeWriter(_ddb->getReadySubDB()->getAttributeManager())); + _aw = std::make_unique<AttributeWriter>(_ddb->getReadySubDB()->getAttributeManager()); 
_sa = _ddb->getReadySubDB()->getSummaryAdapter(); } ~DBContext() @@ -246,7 +245,8 @@ public: Timestamp(0u), docSize, lid, 0u)); LOG_ASSERT(putRes.ok()); uint64_t serialNum = _ddb->getFeedHandler().incSerialNum(); - _aw->put(serialNum, doc, lid, true, std::shared_ptr<IDestructorCallback>()); + _aw->put(serialNum, doc, lid, std::shared_ptr<IDestructorCallback>()); + _aw->forceCommit(serialNum, std::shared_ptr<IDestructorCallback>()); _ddb->getReadySubDB()->getAttributeManager()->getAttributeFieldWriter().sync(); _sa->put(serialNum, lid, doc); const GlobalId &gid = docId.getGlobalId(); @@ -259,10 +259,11 @@ public: op->setSerialNum(serialNum); op->setDbDocumentId(dbdId); op->setPrevDbDocumentId(prevDbdId); - _ddb->getWriteService().master().execute(vespalib::makeLambdaTask([this, op = std::move(op)]() { - _ddb->getFeedHandler().appendOperation(*op, std::make_shared<search::IgnoreCallback>()); + vespalib::Gate commitDone; + _ddb->getWriteService().master().execute(vespalib::makeLambdaTask([this, op = std::move(op), &commitDone]() { + _ddb->getFeedHandler().appendOperation(*op, std::make_shared<search::GateCallback>(commitDone)); })); - _ddb->getWriteService().master().sync(); + commitDone.await(); SearchView *sv(dynamic_cast<SearchView *>(_ddb->getReadySubDB()->getSearchView().get())); if (sv != nullptr) { // cf. 
FeedView::putAttributes() diff --git a/searchcore/src/tests/proton/documentdb/document_subdbs/document_subdbs_test.cpp b/searchcore/src/tests/proton/documentdb/document_subdbs/document_subdbs_test.cpp index 43f16e87986..754cf4ea15d 100644 --- a/searchcore/src/tests/proton/documentdb/document_subdbs/document_subdbs_test.cpp +++ b/searchcore/src/tests/proton/documentdb/document_subdbs/document_subdbs_test.cpp @@ -769,18 +769,27 @@ struct DocumentHandler } void putDoc(PutOperation &op) { IFeedView::SP feedView = _f._subDb.getFeedView(); - _f.runInMaster([&]() { feedView->preparePut(op); - feedView->handlePut(FeedToken(), op); } ); + _f.runInMaster([&]() { + feedView->preparePut(op); + feedView->handlePut(FeedToken(), op); + feedView->forceCommit(op.getSerialNum()); + } ); } void moveDoc(MoveOperation &op) { IFeedView::SP feedView = _f._subDb.getFeedView(); - _f.runInMaster([&]() { feedView->handleMove(op, IDestructorCallback::SP()); } ); + _f.runInMaster([&]() { + feedView->handleMove(op, IDestructorCallback::SP()); + feedView->forceCommit(op.getSerialNum()); + } ); } void removeDoc(RemoveOperation &op) { IFeedView::SP feedView = _f._subDb.getFeedView(); - _f.runInMaster([&]() { feedView->prepareRemove(op); - feedView->handleRemove(FeedToken(), op); } ); + _f.runInMaster([&]() { + feedView->prepareRemove(op); + feedView->handleRemove(FeedToken(), op); + feedView->forceCommit(op.getSerialNum()); + } ); } void putDocs() { PutOperation putOp = createPut(std::move(createDoc(1, 22, 33)), Timestamp(10), 10); diff --git a/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp b/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp index 1269804c98a..9bb8865707d 100644 --- a/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp +++ b/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp @@ -92,20 +92,17 @@ struct MyTracer _os << ")"; } - void tracePut(const vespalib::string &adapterType, - SerialNum serialNum, uint32_t lid, bool 
immediateCommit) { + void tracePut(const vespalib::string &adapterType, SerialNum serialNum, uint32_t lid) { Guard guard(_mutex); addComma(); _os << "put(adapter=" << adapterType << - ",serialNum=" << serialNum << ",lid=" << lid << ",commit=" << immediateCommit << ")"; + ",serialNum=" << serialNum << ",lid=" << lid << ")"; } - void traceRemove(const vespalib::string &adapterType, - SerialNum serialNum, uint32_t lid, bool immediateCommit) { + void traceRemove(const vespalib::string &adapterType, SerialNum serialNum, uint32_t lid) { Guard guard(_mutex); addComma(); - _os << "remove(adapter=" << adapterType << - ",serialNum=" << serialNum << ",lid=" << lid << ",commit=" << immediateCommit << ")"; + _os << "remove(adapter=" << adapterType << ",serialNum=" << serialNum << ",lid=" << lid << ")"; } void traceCommit(const vespalib::string &adapterType, SerialNum serialNum) { @@ -151,12 +148,12 @@ struct MyIndexWriter : public test::MockIndexWriter {} void put(SerialNum serialNum, const document::Document &doc, const DocumentIdT lid) override { (void) doc; - _tracer.tracePut(indexAdapterTypeName, serialNum, lid, false); + _tracer.tracePut(indexAdapterTypeName, serialNum, lid); } void remove(SerialNum serialNum, const search::DocumentIdT lid) override { LOG(info, "MyIndexAdapter::remove(): serialNum(%" PRIu64 "), docId(%u)", serialNum, lid); _removes.push_back(lid); - _tracer.traceRemove(indexAdapterTypeName, serialNum, lid, false); + _tracer.traceRemove(indexAdapterTypeName, serialNum, lid); } void commit(SerialNum serialNum, OnWriteDoneType) override { ++_commitCount; @@ -335,35 +332,26 @@ struct MyAttributeWriter : public IAttributeWriter AttrMap::const_iterator itr = _attrMap.find(attrName); return ((itr == _attrMap.end()) ? 
nullptr : itr->second.get()); } - void put(SerialNum serialNum, const document::Document &doc, DocumentIdT lid, - bool immediateCommit, OnWriteDoneType) override { + void put(SerialNum serialNum, const document::Document &doc, DocumentIdT lid, OnWriteDoneType) override { _putSerial = serialNum; _putDocId = doc.getId(); _putLid = lid; - _tracer.tracePut(attributeAdapterTypeName, serialNum, lid, immediateCommit); - if (immediateCommit) { - ++_commitCount; - } + _tracer.tracePut(attributeAdapterTypeName, serialNum, lid); } - void remove(SerialNum serialNum, DocumentIdT lid, - bool immediateCommit, OnWriteDoneType) override { + void remove(SerialNum serialNum, DocumentIdT lid, OnWriteDoneType) override { _removeSerial = serialNum; _removeLid = lid; - _tracer.traceRemove(attributeAdapterTypeName, serialNum, lid, immediateCommit); - if (immediateCommit) { - ++_commitCount; - } + _tracer.traceRemove(attributeAdapterTypeName, serialNum, lid); } - void remove(const LidVector & lidsToRemove, SerialNum serialNum, - bool immediateCommit, OnWriteDoneType) override { + void remove(const LidVector & lidsToRemove, SerialNum serialNum, OnWriteDoneType) override { for (uint32_t lid : lidsToRemove) { LOG(info, "MyAttributeAdapter::remove(): serialNum(%" PRIu64 "), docId(%u)", serialNum, lid); _removes.push_back(lid); - _tracer.traceRemove(attributeAdapterTypeName, serialNum, lid, immediateCommit); + _tracer.traceRemove(attributeAdapterTypeName, serialNum, lid); } } void update(SerialNum serialNum, const document::DocumentUpdate &upd, - DocumentIdT lid, bool, OnWriteDoneType, IFieldUpdateCallback & onUpdate) override { + DocumentIdT lid, OnWriteDoneType, IFieldUpdateCallback & onUpdate) override { _updateSerial = serialNum; _updateDocId = upd.getId(); _updateLid = lid; @@ -372,12 +360,10 @@ struct MyAttributeWriter : public IAttributeWriter onUpdate.onUpdateField(fieldUpdate.getField().getName(), attr); } } - void update(SerialNum serialNum, const document::Document &doc, DocumentIdT 
lid, - bool immediateCommit, OnWriteDoneType) override { + void update(SerialNum serialNum, const document::Document &doc, DocumentIdT lid, OnWriteDoneType) override { (void) serialNum; (void) doc; (void) lid; - (void) immediateCommit; } void heartBeat(SerialNum) override { ++_heartBeatCount; } void compactLidSpace(uint32_t wantedLidLimit, SerialNum ) override { @@ -818,6 +804,7 @@ TEST_F("require that put() calls attribute adapter", SearchableFeedViewFixture) DocumentContext dc = f.doc1(); EXPECT_EQUAL(0u, f._docIdLimit.get()); f.putAndWait(dc); + f.forceCommitAndWait(); EXPECT_EQUAL(1u, f.maw._putSerial); EXPECT_EQUAL(DocumentId("id:ns:searchdocument::1"), f.maw._putDocId); @@ -1184,26 +1171,6 @@ TEST_F("require that compactLidSpace() propagates to index writer", EXPECT_EQUAL(2u, f.miw._wantedLidLimit); } -TEST_F("require that commit is called if visibility delay is 0", - SearchableFeedViewFixture) -{ - DocumentContext dc = f.doc1(); - f.putAndWait(dc); - EXPECT_EQUAL(1u, f.miw._commitCount); - EXPECT_EQUAL(1u, f.maw._commitCount); - f.removeAndWait(dc); - EXPECT_EQUAL(2u, f.miw._commitCount); - EXPECT_EQUAL(2u, f.maw._commitCount); - f.assertTrace("put(adapter=attribute,serialNum=1,lid=1,commit=1)," - "put(adapter=index,serialNum=1,lid=1,commit=0)," - "commit(adapter=index,serialNum=1)," - "ack(Result(0, ))," - "remove(adapter=attribute,serialNum=2,lid=1,commit=1)," - "remove(adapter=index,serialNum=2,lid=1,commit=0)," - "commit(adapter=index,serialNum=2)," - "ack(Result(0, ))"); -} - const vespalib::duration LONG_DELAY = 60s; const vespalib::duration SHORT_DELAY = 500ms; @@ -1219,11 +1186,11 @@ TEST_F("require that commit is not called when inside a commit interval", EXPECT_EQUAL(0u, f.miw._commitCount); EXPECT_EQUAL(0u, f.maw._commitCount); EXPECT_EQUAL(0u, f._docIdLimit.get()); - f.assertTrace("put(adapter=attribute,serialNum=1,lid=1,commit=0)," - "put(adapter=index,serialNum=1,lid=1,commit=0)," + f.assertTrace("put(adapter=attribute,serialNum=1,lid=1)," + 
"put(adapter=index,serialNum=1,lid=1)," "ack(Result(0, ))," - "remove(adapter=attribute,serialNum=2,lid=1,commit=0)," - "remove(adapter=index,serialNum=2,lid=1,commit=0)," + "remove(adapter=attribute,serialNum=2,lid=1)," + "remove(adapter=index,serialNum=2,lid=1)," "ack(Result(0, ))"); f.forceCommitAndWait(); } @@ -1242,11 +1209,11 @@ TEST_F("require that commit is not implicitly called", EXPECT_EQUAL(0u, f.miw._commitCount); EXPECT_EQUAL(0u, f.maw._commitCount); EXPECT_EQUAL(0u, f._docIdLimit.get()); - f.assertTrace("put(adapter=attribute,serialNum=1,lid=1,commit=0)," - "put(adapter=index,serialNum=1,lid=1,commit=0)," + f.assertTrace("put(adapter=attribute,serialNum=1,lid=1)," + "put(adapter=index,serialNum=1,lid=1)," "ack(Result(0, ))," - "remove(adapter=attribute,serialNum=2,lid=1,commit=0)," - "remove(adapter=index,serialNum=2,lid=1,commit=0)," + "remove(adapter=attribute,serialNum=2,lid=1)," + "remove(adapter=index,serialNum=2,lid=1)," "ack(Result(0, ))"); f.forceCommitAndWait(); } @@ -1263,8 +1230,8 @@ TEST_F("require that forceCommit updates docid limit", EXPECT_EQUAL(1u, f.miw._commitCount); EXPECT_EQUAL(1u, f.maw._commitCount); EXPECT_EQUAL(2u, f._docIdLimit.get()); - f.assertTrace("put(adapter=attribute,serialNum=1,lid=1,commit=0)," - "put(adapter=index,serialNum=1,lid=1,commit=0)," + f.assertTrace("put(adapter=attribute,serialNum=1,lid=1)," + "put(adapter=index,serialNum=1,lid=1)," "ack(Result(0, ))," "commit(adapter=attribute,serialNum=1)," "commit(adapter=index,serialNum=1)"); diff --git a/searchcore/src/tests/proton/documentdb/storeonlyfeedview/storeonlyfeedview_test.cpp b/searchcore/src/tests/proton/documentdb/storeonlyfeedview/storeonlyfeedview_test.cpp index e6e71d51e47..3a75f8cd494 100644 --- a/searchcore/src/tests/proton/documentdb/storeonlyfeedview/storeonlyfeedview_test.cpp +++ b/searchcore/src/tests/proton/documentdb/storeonlyfeedview/storeonlyfeedview_test.cpp @@ -105,12 +105,12 @@ struct MyMinimalFeedView : public MyMinimalFeedViewBase, 
public StoreOnlyFeedVie outstandingMoveOps(outstandingMoveOps_) { } - void removeAttributes(SerialNum s, const LidVector &l, bool immediateCommit, OnWriteDoneType onWriteDone) override { - StoreOnlyFeedView::removeAttributes(s, l, immediateCommit, onWriteDone); + void removeAttributes(SerialNum s, const LidVector &l, OnWriteDoneType onWriteDone) override { + StoreOnlyFeedView::removeAttributes(s, l, onWriteDone); ++removeMultiAttributesCount; } - void removeIndexedFields(SerialNum s, const LidVector &l, bool immediateCommit, OnWriteDoneType onWriteDone) override { - StoreOnlyFeedView::removeIndexedFields(s, l, immediateCommit, onWriteDone); + void removeIndexedFields(SerialNum s, const LidVector &l, OnWriteDoneType onWriteDone) override { + StoreOnlyFeedView::removeIndexedFields(s, l, onWriteDone); ++removeMultiIndexFieldsCount; } void heartBeatIndexedFields(SerialNum s) override { @@ -145,23 +145,23 @@ struct MoveOperationFeedView : public MyMinimalFeedView { removeIndexFieldsCount(0), onWriteDoneContexts() {} - void putAttributes(SerialNum, search::DocumentIdT, const document::Document &, bool, OnPutDoneType onWriteDone) override { + void putAttributes(SerialNum, search::DocumentIdT, const document::Document &, OnPutDoneType onWriteDone) override { ++putAttributesCount; EXPECT_EQUAL(1, outstandingMoveOps); onWriteDoneContexts.push_back(onWriteDone); } void putIndexedFields(SerialNum, search::DocumentIdT, const document::Document::SP &, - bool, OnOperationDoneType onWriteDone) override { + OnOperationDoneType onWriteDone) override { ++putIndexFieldsCount; EXPECT_EQUAL(1, outstandingMoveOps); onWriteDoneContexts.push_back(onWriteDone); } - void removeAttributes(SerialNum, search::DocumentIdT, bool, OnRemoveDoneType onWriteDone) override { + void removeAttributes(SerialNum, search::DocumentIdT, OnRemoveDoneType onWriteDone) override { ++removeAttributesCount; EXPECT_EQUAL(1, outstandingMoveOps); onWriteDoneContexts.push_back(onWriteDone); } - void 
removeIndexedFields(SerialNum, search::DocumentIdT, bool, OnRemoveDoneType onWriteDone) override { + void removeIndexedFields(SerialNum, search::DocumentIdT, OnRemoveDoneType onWriteDone) override { ++removeIndexFieldsCount; EXPECT_EQUAL(1, outstandingMoveOps); onWriteDoneContexts.push_back(onWriteDone); diff --git a/searchcore/src/vespa/searchcore/proton/attribute/attribute_populator.cpp b/searchcore/src/vespa/searchcore/proton/attribute/attribute_populator.cpp index 33a5776cb8a..af7bae32b11 100644 --- a/searchcore/src/vespa/searchcore/proton/attribute/attribute_populator.cpp +++ b/searchcore/src/vespa/searchcore/proton/attribute/attribute_populator.cpp @@ -3,6 +3,8 @@ #include "attribute_populator.h" #include <vespa/searchcore/proton/common/eventlogger.h> #include <vespa/searchlib/common/idestructorcallback.h> +#include <vespa/searchlib/common/gatecallback.h> +#include <vespa/vespalib/util/gate.h> #include <vespa/searchlib/attribute/attributevector.h> #include <vespa/log/log.h> @@ -73,8 +75,10 @@ void AttributePopulator::handleExisting(uint32_t lid, const std::shared_ptr<document::Document> &doc) { search::SerialNum serialNum(nextSerialNum()); - auto populateDoneContext = std::make_shared<PopulateDoneContext>(doc); - _writer.put(serialNum, *doc, lid, true, populateDoneContext); + _writer.put(serialNum, *doc, lid, std::make_shared<PopulateDoneContext>(doc)); + vespalib::Gate gate; + _writer.forceCommit(serialNum, std::make_shared<search::GateCallback>(gate)); + gate.await(); } void diff --git a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp index bf32b679d76..2b859c17931 100644 --- a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp +++ b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp @@ -127,8 +127,7 @@ ensureLidSpace(SerialNum serialNum, DocumentIdT lid, AttributeVector &attr) void applyPutToAttribute(SerialNum serialNum, 
const FieldValue::UP &fieldValue, DocumentIdT lid, - bool immediateCommit, AttributeVector &attr, - AttributeWriter::OnWriteDoneType) + AttributeVector &attr, AttributeWriter::OnWriteDoneType) { ensureLidSpace(serialNum, lid, attr); if (fieldValue.get()) { @@ -136,9 +135,6 @@ applyPutToAttribute(SerialNum serialNum, const FieldValue::UP &fieldValue, Docum } else { attr.clearDoc(lid); } - if (immediateCommit) { - attr.commit(serialNum, serialNum); - } } void @@ -147,7 +143,6 @@ complete_put_to_attribute(SerialNum serial_num, AttributeVector& attr, const FieldValue::SP& field_value, std::future<std::unique_ptr<PrepareResult>>& result_future, - bool immediate_commit, AttributeWriter::OnWriteDoneType) { ensureLidSpace(serial_num, docid, attr); @@ -157,20 +152,14 @@ complete_put_to_attribute(SerialNum serial_num, } else { attr.clearDoc(docid); } - if (immediate_commit) { - attr.commit(serial_num, serial_num); - } } void -applyRemoveToAttribute(SerialNum serialNum, DocumentIdT lid, bool immediateCommit, +applyRemoveToAttribute(SerialNum serialNum, DocumentIdT lid, AttributeVector &attr, AttributeWriter::OnWriteDoneType) { ensureLidSpace(serialNum, lid, attr); attr.clearDoc(lid); - if (immediateCommit) { - attr.commit(serialNum, serialNum); - } } void @@ -182,15 +171,6 @@ applyUpdateToAttribute(SerialNum serialNum, const FieldUpdate &fieldUpd, } void -applyUpdateToAttributeAndCommit(SerialNum serialNum, const FieldUpdate &fieldUpd, - DocumentIdT lid, AttributeVector &attr) -{ - ensureLidSpace(serialNum, lid, attr); - AttributeUpdater::handleUpdate(attr, lid, fieldUpd); - attr.commit(serialNum, serialNum); -} - -void applyReplayDone(uint32_t docIdLimit, AttributeVector &attr) { AttributeManager::padAttribute(attr, docIdLimit); @@ -240,30 +220,22 @@ using AttrUpdates = std::vector<std::pair<AttributeVector *, const FieldUpdate * struct BatchUpdateTask : public vespalib::Executor::Task { - BatchUpdateTask(SerialNum serialNum, DocumentIdT lid, bool immediateCommit) + 
BatchUpdateTask(SerialNum serialNum, DocumentIdT lid) : vespalib::Executor::Task(), _serialNum(serialNum), _lid(lid), - _immediateCommit(immediateCommit), _onWriteDone() { } ~BatchUpdateTask() override; void run() override { - if (_immediateCommit) { - for (const auto & update : _updates) { - applyUpdateToAttributeAndCommit(_serialNum, *update.second, _lid, *update.first); - } - } else { - for (const auto & update : _updates) { - applyUpdateToAttribute(_serialNum, *update.second, _lid, *update.first); - } + for (const auto & update : _updates) { + applyUpdateToAttribute(_serialNum, *update.second, _lid, *update.first); } } SerialNum _serialNum; DocumentIdT _lid; - bool _immediateCommit; AttrUpdates _updates; search::IDestructorCallback::SP _onWriteDone; }; @@ -310,22 +282,20 @@ class PutTask : public vespalib::Executor::Task const AttributeWriter::WriteContext &_wc; const SerialNum _serialNum; const uint32_t _lid; - const bool _immediateCommit; const bool _allAttributes; std::remove_reference_t<AttributeWriter::OnWriteDoneType> _onWriteDone; std::shared_ptr<DocumentFieldExtractor> _fieldExtractor; std::vector<FieldValue::UP> _fieldValues; public: - PutTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, std::shared_ptr<DocumentFieldExtractor> fieldExtractor, uint32_t lid, bool immediateCommit, bool allAttributes, AttributeWriter::OnWriteDoneType onWriteDone); + PutTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, std::shared_ptr<DocumentFieldExtractor> fieldExtractor, uint32_t lid, bool allAttributes, AttributeWriter::OnWriteDoneType onWriteDone); ~PutTask() override; void run() override; }; -PutTask::PutTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, std::shared_ptr<DocumentFieldExtractor> fieldExtractor, uint32_t lid, bool immediateCommit, bool allAttributes, AttributeWriter::OnWriteDoneType onWriteDone) +PutTask::PutTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, 
std::shared_ptr<DocumentFieldExtractor> fieldExtractor, uint32_t lid, bool allAttributes, AttributeWriter::OnWriteDoneType onWriteDone) : _wc(wc), _serialNum(serialNum), _lid(lid), - _immediateCommit(immediateCommit), _allAttributes(allAttributes), _onWriteDone(onWriteDone), _fieldExtractor(std::move(fieldExtractor)), @@ -352,7 +322,7 @@ PutTask::run() if (_allAttributes || field.isStructFieldAttribute()) { AttributeVector &attr = field.getAttribute(); if (attr.getStatus().getLastSyncToken() < _serialNum) { - applyPutToAttribute(_serialNum, _fieldValues[fieldId], _lid, _immediateCommit, attr, _onWriteDone); + applyPutToAttribute(_serialNum, _fieldValues[fieldId], _lid, attr, _onWriteDone); } ++fieldId; } @@ -418,26 +388,22 @@ private: AttributeVector& _attr; FieldValue::SP _field_value; std::future<std::unique_ptr<PrepareResult>> _result_future; - const bool _immediate_commit; std::remove_reference_t<AttributeWriter::OnWriteDoneType> _on_write_done; public: CompletePutTask(PreparePutTask& prepare_task, - bool immediate_commit, AttributeWriter::OnWriteDoneType on_write_done); ~CompletePutTask() override; void run() override; }; CompletePutTask::CompletePutTask(PreparePutTask& prepare_task, - bool immediate_commit, AttributeWriter::OnWriteDoneType on_write_done) : _serial_num(prepare_task.serial_num()), _docid(prepare_task.docid()), _attr(prepare_task.attr()), _field_value(prepare_task.field_value()), _result_future(prepare_task.result_future()), - _immediate_commit(immediate_commit), _on_write_done(on_write_done) { } @@ -448,8 +414,7 @@ void CompletePutTask::run() { if (_attr.getStatus().getLastSyncToken() < _serial_num) { - complete_put_to_attribute(_serial_num, _docid, _attr, _field_value, _result_future, - _immediate_commit, _on_write_done); + complete_put_to_attribute(_serial_num, _docid, _attr, _field_value, _result_future, _on_write_done); } } @@ -458,19 +423,17 @@ class RemoveTask : public vespalib::Executor::Task const AttributeWriter::WriteContext &_wc; 
const SerialNum _serialNum; const uint32_t _lid; - const bool _immediateCommit; std::remove_reference_t<AttributeWriter::OnWriteDoneType> _onWriteDone; public: - RemoveTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, uint32_t lid, bool immediateCommit, AttributeWriter::OnWriteDoneType onWriteDone); + RemoveTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, uint32_t lid, AttributeWriter::OnWriteDoneType onWriteDone); ~RemoveTask() override; void run() override; }; -RemoveTask::RemoveTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, uint32_t lid, bool immediateCommit, AttributeWriter::OnWriteDoneType onWriteDone) +RemoveTask::RemoveTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, uint32_t lid, AttributeWriter::OnWriteDoneType onWriteDone) : _wc(wc), _serialNum(serialNum), _lid(lid), - _immediateCommit(immediateCommit), _onWriteDone(onWriteDone) { } @@ -485,7 +448,7 @@ RemoveTask::run() AttributeVector &attr = field.getAttribute(); // Must use <= due to how move operations are handled if (attr.getStatus().getLastSyncToken() <= _serialNum) { - applyRemoveToAttribute(_serialNum, _lid, _immediateCommit, attr, _onWriteDone); + applyRemoveToAttribute(_serialNum, _lid, attr, _onWriteDone); } } } @@ -496,18 +459,15 @@ private: const AttributeWriter::WriteContext &_writeCtx; const SerialNum _serialNum; const LidVector _lidsToRemove; - const bool _immediateCommit; std::remove_reference_t<AttributeWriter::OnWriteDoneType> _onWriteDone; public: BatchRemoveTask(const AttributeWriter::WriteContext &writeCtx, SerialNum serialNum, const LidVector &lidsToRemove, - bool immediateCommit, AttributeWriter::OnWriteDoneType onWriteDone) : _writeCtx(writeCtx), _serialNum(serialNum), _lidsToRemove(lidsToRemove), - _immediateCommit(immediateCommit), _onWriteDone(onWriteDone) {} ~BatchRemoveTask() override; @@ -516,10 +476,7 @@ public: auto &attr = field.getAttribute(); if (attr.getStatus().getLastSyncToken() < _serialNum) { 
for (auto lidToRemove : _lidsToRemove) { - applyRemoveToAttribute(_serialNum, lidToRemove, false, attr, _onWriteDone); - } - if (_immediateCommit) { - attr.commit(_serialNum, _serialNum); + applyRemoveToAttribute(_serialNum, lidToRemove, attr, _onWriteDone); } } } @@ -604,7 +561,7 @@ AttributeWriter::buildFieldPaths(const DocumentType & docType, const DataType *d void AttributeWriter::internalPut(SerialNum serialNum, const Document &doc, DocumentIdT lid, - bool immediateCommit, bool allAttributes, OnWriteDoneType onWriteDone) + bool allAttributes, OnWriteDoneType onWriteDone) { const DataType *dataType(doc.getDataType()); if (_dataType != dataType) { @@ -615,13 +572,12 @@ AttributeWriter::internalPut(SerialNum serialNum, const Document &doc, DocumentI if (wc.use_two_phase_put()) { assert(wc.getFields().size() == 1); auto prepare_task = std::make_unique<PreparePutTask>(serialNum, lid, wc.getFields()[0], extractor); - auto complete_task = std::make_unique<CompletePutTask>(*prepare_task, immediateCommit, onWriteDone); + auto complete_task = std::make_unique<CompletePutTask>(*prepare_task, onWriteDone); _shared_executor.execute(std::move(prepare_task)); _attributeFieldWriter.executeTask(wc.getExecutorId(), std::move(complete_task)); } else { if (allAttributes || wc.hasStructFieldAttribute()) { - auto putTask = std::make_unique<PutTask>(wc, serialNum, extractor, lid, immediateCommit, allAttributes, - onWriteDone); + auto putTask = std::make_unique<PutTask>(wc, serialNum, extractor, lid, allAttributes, onWriteDone); _attributeFieldWriter.executeTask(wc.getExecutorId(), std::move(putTask)); } } @@ -629,11 +585,10 @@ AttributeWriter::internalPut(SerialNum serialNum, const Document &doc, DocumentI } void -AttributeWriter::internalRemove(SerialNum serialNum, DocumentIdT lid, bool immediateCommit, - OnWriteDoneType onWriteDone) +AttributeWriter::internalRemove(SerialNum serialNum, DocumentIdT lid, OnWriteDoneType onWriteDone) { for (const auto &wc : _writeContexts) { - auto 
removeTask = std::make_unique<RemoveTask>(wc, serialNum, lid, immediateCommit, onWriteDone); + auto removeTask = std::make_unique<RemoveTask>(wc, serialNum, lid, onWriteDone); _attributeFieldWriter.executeTask(wc.getExecutorId(), std::move(removeTask)); } } @@ -678,50 +633,46 @@ AttributeWriter::getWritableAttribute(const vespalib::string &name) const } void -AttributeWriter::put(SerialNum serialNum, const Document &doc, DocumentIdT lid, - bool immediateCommit, OnWriteDoneType onWriteDone) +AttributeWriter::put(SerialNum serialNum, const Document &doc, DocumentIdT lid, OnWriteDoneType onWriteDone) { LOG(spam, "Handle put: serial(%" PRIu64 "), docId(%s), lid(%u), document(%s)", serialNum, doc.getId().toString().c_str(), lid, doc.toString(true).c_str()); - internalPut(serialNum, doc, lid, immediateCommit, true, onWriteDone); + internalPut(serialNum, doc, lid, true, onWriteDone); } void -AttributeWriter::update(SerialNum serialNum, const Document &doc, DocumentIdT lid, - bool immediateCommit, OnWriteDoneType onWriteDone) +AttributeWriter::update(SerialNum serialNum, const Document &doc, DocumentIdT lid, OnWriteDoneType onWriteDone) { LOG(spam, "Handle update: serial(%" PRIu64 "), docId(%s), lid(%u), document(%s)", serialNum, doc.getId().toString().c_str(), lid, doc.toString(true).c_str()); - internalPut(serialNum, doc, lid, immediateCommit, false, onWriteDone); + internalPut(serialNum, doc, lid, false, onWriteDone); } void -AttributeWriter::remove(SerialNum serialNum, DocumentIdT lid, - bool immediateCommit, OnWriteDoneType onWriteDone) +AttributeWriter::remove(SerialNum serialNum, DocumentIdT lid, OnWriteDoneType onWriteDone) { - internalRemove(serialNum, lid, immediateCommit, onWriteDone); + internalRemove(serialNum, lid, onWriteDone); } void -AttributeWriter::remove(const LidVector &lidsToRemove, SerialNum serialNum, - bool immediateCommit, OnWriteDoneType onWriteDone) +AttributeWriter::remove(const LidVector &lidsToRemove, SerialNum serialNum, OnWriteDoneType 
onWriteDone) { for (const auto &writeCtx : _writeContexts) { - auto removeTask = std::make_unique<BatchRemoveTask>(writeCtx, serialNum, lidsToRemove, immediateCommit, onWriteDone); + auto removeTask = std::make_unique<BatchRemoveTask>(writeCtx, serialNum, lidsToRemove, onWriteDone); _attributeFieldWriter.executeTask(writeCtx.getExecutorId(), std::move(removeTask)); } } void AttributeWriter::update(SerialNum serialNum, const DocumentUpdate &upd, DocumentIdT lid, - bool immediateCommit, OnWriteDoneType onWriteDone, IFieldUpdateCallback & onUpdate) + OnWriteDoneType onWriteDone, IFieldUpdateCallback & onUpdate) { LOG(debug, "Inspecting update for document %d.", lid); std::vector<std::unique_ptr<BatchUpdateTask>> args; uint32_t numExecutors = _attributeFieldWriter.getNumExecutors(); args.reserve(numExecutors); for (uint32_t i(0); i < numExecutors; i++) { - args.emplace_back(std::make_unique<BatchUpdateTask>(serialNum, lid, immediateCommit)); + args.emplace_back(std::make_unique<BatchUpdateTask>(serialNum, lid)); args.back()->_updates.reserve((2*upd.getUpdates().size())/numExecutors); } diff --git a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h index 9e9e8910669..f63a2c6efba 100644 --- a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h +++ b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h @@ -79,13 +79,12 @@ private: void setupAttriuteMapping(); void buildFieldPaths(const DocumentType &docType, const DataType *dataType); void internalPut(SerialNum serialNum, const Document &doc, DocumentIdT lid, - bool immediateCommit, bool allAttributes, OnWriteDoneType onWriteDone); - void internalRemove(SerialNum serialNum, DocumentIdT lid, - bool immediateCommit, OnWriteDoneType onWriteDone); + bool allAttributes, OnWriteDoneType onWriteDone); + void internalRemove(SerialNum serialNum, DocumentIdT lid, OnWriteDoneType onWriteDone); public: 
AttributeWriter(proton::IAttributeManager::SP mgr); - ~AttributeWriter(); + ~AttributeWriter() override; /* Only for in tests that add attributes after AttributeWriter construction. */ @@ -94,16 +93,12 @@ public: */ std::vector<search::AttributeVector *> getWritableAttributes() const override; search::AttributeVector *getWritableAttribute(const vespalib::string &name) const override; - void put(SerialNum serialNum, const Document &doc, DocumentIdT lid, - bool immediateCommit, OnWriteDoneType onWriteDone) override; - void remove(SerialNum serialNum, DocumentIdT lid, - bool immediateCommit, OnWriteDoneType onWriteDone) override; - void remove(const LidVector &lidVector, SerialNum serialNum, - bool immediateCommit, OnWriteDoneType onWriteDone) override; + void put(SerialNum serialNum, const Document &doc, DocumentIdT lid, OnWriteDoneType onWriteDone) override; + void remove(SerialNum serialNum, DocumentIdT lid, OnWriteDoneType onWriteDone) override; + void remove(const LidVector &lidVector, SerialNum serialNum, OnWriteDoneType onWriteDone) override; void update(SerialNum serialNum, const DocumentUpdate &upd, DocumentIdT lid, - bool immediateCommit, OnWriteDoneType onWriteDone, IFieldUpdateCallback & onUpdate) override; - void update(SerialNum serialNum, const Document &doc, DocumentIdT lid, - bool immediateCommit, OnWriteDoneType onWriteDone) override; + OnWriteDoneType onWriteDone, IFieldUpdateCallback & onUpdate) override; + void update(SerialNum serialNum, const Document &doc, DocumentIdT lid, OnWriteDoneType onWriteDone) override; void heartBeat(SerialNum serialNum) override; void compactLidSpace(uint32_t wantedLidLimit, SerialNum serialNum) override; const proton::IAttributeManager::SP &getAttributeManager() const override { diff --git a/searchcore/src/vespa/searchcore/proton/attribute/filter_attribute_manager.cpp b/searchcore/src/vespa/searchcore/proton/attribute/filter_attribute_manager.cpp index 3b1269b031c..ffdfdbc4332 100644 --- 
a/searchcore/src/vespa/searchcore/proton/attribute/filter_attribute_manager.cpp +++ b/searchcore/src/vespa/searchcore/proton/attribute/filter_attribute_manager.cpp @@ -233,7 +233,7 @@ FilterAttributeManager::setImportedAttributes(std::unique_ptr<ImportedAttributes const ImportedAttributesRepo * FilterAttributeManager::getImportedAttributes() const { - throw vespalib::IllegalArgumentException("Not implemented"); + return nullptr; } std::shared_ptr<search::attribute::ReadableAttributeVector> diff --git a/searchcore/src/vespa/searchcore/proton/attribute/i_attribute_writer.h b/searchcore/src/vespa/searchcore/proton/attribute/i_attribute_writer.h index 99b5728fd3a..789a8077cba 100644 --- a/searchcore/src/vespa/searchcore/proton/attribute/i_attribute_writer.h +++ b/searchcore/src/vespa/searchcore/proton/attribute/i_attribute_writer.h @@ -33,27 +33,23 @@ public: typedef document::Document Document; using OnWriteDoneType = const std::shared_ptr<search::IDestructorCallback> &; - virtual ~IAttributeWriter() {} + virtual ~IAttributeWriter() = default; virtual std::vector<search::AttributeVector *> getWritableAttributes() const = 0; virtual search::AttributeVector *getWritableAttribute(const vespalib::string &attrName) const = 0; - virtual void put(SerialNum serialNum, const Document &doc, DocumentIdT lid, - bool immediateCommit, OnWriteDoneType onWriteDone) = 0; - virtual void remove(SerialNum serialNum, DocumentIdT lid, bool immediateCommit, - OnWriteDoneType onWriteDone) = 0; - virtual void remove(const LidVector &lidVector, SerialNum serialNum, - bool immediateCommit, OnWriteDoneType onWriteDone) = 0; + virtual void put(SerialNum serialNum, const Document &doc, DocumentIdT lid, OnWriteDoneType onWriteDone) = 0; + virtual void remove(SerialNum serialNum, DocumentIdT lid, OnWriteDoneType onWriteDone) = 0; + virtual void remove(const LidVector &lidVector, SerialNum serialNum, OnWriteDoneType onWriteDone) = 0; /** * Update the underlying attributes based on the content of the 
given DocumentUpdate. * The OnWriteDoneType instance should ensure the lifetime of the given DocumentUpdate instance. */ virtual void update(SerialNum serialNum, const DocumentUpdate &upd, DocumentIdT lid, - bool immediateCommit, OnWriteDoneType onWriteDone, IFieldUpdateCallback & onUpdate) = 0; + OnWriteDoneType onWriteDone, IFieldUpdateCallback & onUpdate) = 0; /* * Update the underlying struct field attributes based on updated document. */ - virtual void update(SerialNum serialNum, const Document &doc, DocumentIdT lid, - bool immediateCommit, OnWriteDoneType onWriteDone) = 0; + virtual void update(SerialNum serialNum, const Document &doc, DocumentIdT lid, OnWriteDoneType onWriteDone) = 0; virtual void heartBeat(SerialNum serialNum) = 0; /** * Compact the lid space of the underlying attribute vectors. diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.cpp b/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.cpp index e547f3556be..7fab995dfb9 100644 --- a/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.cpp +++ b/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.cpp @@ -31,7 +31,6 @@ DocumentMetaStoreAttribute::DocumentMetaStoreAttribute(const vespalib::string &n { } -DocumentMetaStoreAttribute::~DocumentMetaStoreAttribute() -{ } +DocumentMetaStoreAttribute::~DocumentMetaStoreAttribute() = default; } diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.h b/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.h index 5b286907fb8..721aa8fe126 100644 --- a/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.h +++ b/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.h @@ -14,22 +14,20 @@ namespace proton { class DocumentMetaStoreAttribute : public search::NotImplementedAttribute { 
protected: - virtual void notImplemented() const override __attribute__((noinline)); + void notImplemented() const override __attribute__((noinline)); public: DocumentMetaStoreAttribute(const vespalib::string &name=getFixedName()); - virtual ~DocumentMetaStoreAttribute(); + ~DocumentMetaStoreAttribute() override; static const vespalib::string &getFixedName(); // Implements IAttributeVector - virtual size_t - getFixedWidth() const override - { + size_t getFixedWidth() const override { return document::GlobalId::LENGTH; } - virtual void onCommit() override {} + void onCommit() override {} }; } diff --git a/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.cpp b/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.cpp index eec58ed53dc..52b4d869ce8 100644 --- a/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.cpp @@ -20,46 +20,39 @@ namespace proton { * Otherwise we can drop it and ack the operation right away. 
*/ void -FastAccessFeedView::putAttributes(SerialNum serialNum, search::DocumentIdT lid, const Document &doc, - bool immediateCommit, OnPutDoneType onWriteDone) +FastAccessFeedView::putAttributes(SerialNum serialNum, search::DocumentIdT lid, const Document &doc, OnPutDoneType onWriteDone) { - _attributeWriter->put(serialNum, doc, lid, immediateCommit, onWriteDone); - if (immediateCommit && onWriteDone) { - onWriteDone->registerPutLid(&_docIdLimit); - } + _attributeWriter->put(serialNum, doc, lid, onWriteDone); } void FastAccessFeedView::updateAttributes(SerialNum serialNum, search::DocumentIdT lid, const DocumentUpdate &upd, - bool immediateCommit, OnOperationDoneType onWriteDone, IFieldUpdateCallback & onUpdate) + OnOperationDoneType onWriteDone, IFieldUpdateCallback & onUpdate) { - _attributeWriter->update(serialNum, upd, lid, immediateCommit, onWriteDone, onUpdate); + _attributeWriter->update(serialNum, upd, lid, onWriteDone, onUpdate); } void -FastAccessFeedView::updateAttributes(SerialNum serialNum, Lid lid, FutureDoc futureDoc, - bool immediateCommit, OnOperationDoneType onWriteDone) +FastAccessFeedView::updateAttributes(SerialNum serialNum, Lid lid, FutureDoc futureDoc, OnOperationDoneType onWriteDone) { if (_attributeWriter->hasStructFieldAttribute()) { const std::unique_ptr<const Document> & doc = futureDoc.get(); if (doc) { - _attributeWriter->update(serialNum, *doc, lid, immediateCommit, onWriteDone); + _attributeWriter->update(serialNum, *doc, lid, onWriteDone); } } } void -FastAccessFeedView::removeAttributes(SerialNum serialNum, search::DocumentIdT lid, - bool immediateCommit, OnRemoveDoneType onWriteDone) +FastAccessFeedView::removeAttributes(SerialNum serialNum, search::DocumentIdT lid, OnRemoveDoneType onWriteDone) { - _attributeWriter->remove(serialNum, lid, immediateCommit, onWriteDone); + _attributeWriter->remove(serialNum, lid, onWriteDone); } void -FastAccessFeedView::removeAttributes(SerialNum serialNum, const LidVector &lidsToRemove, - bool 
immediateCommit, OnWriteDoneType onWriteDone) +FastAccessFeedView::removeAttributes(SerialNum serialNum, const LidVector &lidsToRemove, OnWriteDoneType onWriteDone) { - _attributeWriter->remove(lidsToRemove, serialNum, immediateCommit, onWriteDone); + _attributeWriter->remove(lidsToRemove, serialNum, onWriteDone); } void diff --git a/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.h b/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.h index 08f11869b08..e0823be3e43 100644 --- a/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.h +++ b/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.h @@ -36,18 +36,14 @@ private: const IAttributeWriter::SP _attributeWriter; DocIdLimit &_docIdLimit; - void putAttributes(SerialNum serialNum, search::DocumentIdT lid, const Document &doc, - bool immediateCommit, OnPutDoneType onWriteDone) override; + void putAttributes(SerialNum serialNum, search::DocumentIdT lid, const Document &doc, OnPutDoneType onWriteDone) override; void updateAttributes(SerialNum serialNum, search::DocumentIdT lid, const document::DocumentUpdate &upd, - bool immediateCommit, OnOperationDoneType onWriteDone, IFieldUpdateCallback & onUpdate) override; - void updateAttributes(SerialNum serialNum, Lid lid, FutureDoc doc, - bool immediateCommit, OnOperationDoneType onWriteDone) override; - void removeAttributes(SerialNum serialNum, search::DocumentIdT lid, - bool immediateCommit, OnRemoveDoneType onWriteDone) override; - - void removeAttributes(SerialNum serialNum, const LidVector &lidsToRemove, - bool immediateCommit, OnWriteDoneType onWriteDone) override; + OnOperationDoneType onWriteDone, IFieldUpdateCallback & onUpdate) override; + void updateAttributes(SerialNum serialNum, Lid lid, FutureDoc doc, OnOperationDoneType onWriteDone) override; + void removeAttributes(SerialNum serialNum, search::DocumentIdT lid, OnRemoveDoneType onWriteDone) override; + + void removeAttributes(SerialNum 
serialNum, const LidVector &lidsToRemove, OnWriteDoneType onWriteDone) override; void heartBeatAttributes(SerialNum serialNum) override; diff --git a/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.cpp b/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.cpp index 360cac6e2ee..ebef7b4b6d4 100644 --- a/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.cpp @@ -59,46 +59,41 @@ SearchableFeedView::sync() void SearchableFeedView::putIndexedFields(SerialNum serialNum, search::DocumentIdT lid, const DocumentSP &newDoc, - bool immediateCommit, OnOperationDoneType onWriteDone) + OnOperationDoneType onWriteDone) { if (!hasIndexedFields()) { return; } _writeService.index().execute( makeLambdaTask([=] { - performIndexPut(serialNum, lid, newDoc, immediateCommit, onWriteDone); + performIndexPut(serialNum, lid, newDoc, onWriteDone); })); } void -SearchableFeedView::performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const Document &doc, - bool immediateCommit, OnOperationDoneType onWriteDone) +SearchableFeedView::performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const Document &doc, OnOperationDoneType onWriteDone) { + (void) onWriteDone; assert(_writeService.index().isCurrentThread()); VLOG(getDebugLevel(lid, doc.getId()), "database(%s): performIndexPut: serialNum(%" PRIu64 "), docId(%s), lid(%d)", _params._docTypeName.toString().c_str(), serialNum, doc.getId().toString().c_str(), lid); _indexWriter->put(serialNum, doc, lid); - if (immediateCommit) { - _indexWriter->commit(serialNum, onWriteDone); - } } void -SearchableFeedView::performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const DocumentSP &doc, - bool immediateCommit, OnOperationDoneType onWriteDone) +SearchableFeedView::performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const DocumentSP &doc, OnOperationDoneType onWriteDone) { - 
performIndexPut(serialNum, lid, *doc, immediateCommit, onWriteDone); + performIndexPut(serialNum, lid, *doc, onWriteDone); } void -SearchableFeedView::performIndexPut(SerialNum serialNum, search::DocumentIdT lid, FutureDoc futureDoc, - bool immediateCommit, OnOperationDoneType onWriteDone) +SearchableFeedView::performIndexPut(SerialNum serialNum, search::DocumentIdT lid, FutureDoc futureDoc, OnOperationDoneType onWriteDone) { const auto &doc = futureDoc.get(); if (doc) { - performIndexPut(serialNum, lid, *doc, immediateCommit, onWriteDone); + performIndexPut(serialNum, lid, *doc, onWriteDone); } } @@ -115,49 +110,44 @@ SearchableFeedView::performIndexHeartBeat(SerialNum serialNum) } void -SearchableFeedView::updateIndexedFields(SerialNum serialNum, search::DocumentIdT lid, FutureDoc futureDoc, - bool immediateCommit, OnOperationDoneType onWriteDone) +SearchableFeedView::updateIndexedFields(SerialNum serialNum, search::DocumentIdT lid, FutureDoc futureDoc, OnOperationDoneType onWriteDone) { _writeService.index().execute( makeLambdaTask([serialNum, lid, futureDoc = std::move(futureDoc), - immediateCommit, onWriteDone = std::move(onWriteDone), this]() mutable { - performIndexPut(serialNum, lid, std::move(futureDoc), immediateCommit, std::move(onWriteDone)); + onWriteDone = std::move(onWriteDone), this]() mutable { + performIndexPut(serialNum, lid, std::move(futureDoc), std::move(onWriteDone)); })); } void -SearchableFeedView::removeIndexedFields(SerialNum serialNum, search::DocumentIdT lid, - bool immediateCommit, OnRemoveDoneType onWriteDone) +SearchableFeedView::removeIndexedFields(SerialNum serialNum, search::DocumentIdT lid, OnRemoveDoneType onWriteDone) { if (!hasIndexedFields()) { return; } _writeService.index().execute( makeLambdaTask([=]() { - performIndexRemove(serialNum, lid, immediateCommit, onWriteDone); + performIndexRemove(serialNum, lid, onWriteDone); })); } void -SearchableFeedView::performIndexRemove(SerialNum serialNum, search::DocumentIdT lid, - 
bool immediateCommit, OnRemoveDoneType onWriteDone) +SearchableFeedView::performIndexRemove(SerialNum serialNum, search::DocumentIdT lid, OnRemoveDoneType onWriteDone) { + (void) onWriteDone; assert(_writeService.index().isCurrentThread()); VLOG(getDebugLevel(lid, nullptr), "database(%s): performIndexRemove: serialNum(%" PRIu64 "), lid(%d)", _params._docTypeName.toString().c_str(), serialNum, lid); _indexWriter->remove(serialNum, lid); - if (immediateCommit) { - _indexWriter->commit(serialNum, onWriteDone); - } } void -SearchableFeedView::performIndexRemove(SerialNum serialNum, const LidVector &lidsToRemove, - bool immediateCommit, OnWriteDoneType onWriteDone) +SearchableFeedView::performIndexRemove(SerialNum serialNum, const LidVector &lidsToRemove, OnWriteDoneType onWriteDone) { + (void) onWriteDone; assert(_writeService.index().isCurrentThread()); for (const auto lid : lidsToRemove) { VLOG(getDebugLevel(lid, nullptr), @@ -166,21 +156,18 @@ SearchableFeedView::performIndexRemove(SerialNum serialNum, const LidVector &lid _indexWriter->remove(serialNum, lid); } - if (immediateCommit) { - _indexWriter->commit(serialNum, onWriteDone); - } } void SearchableFeedView::removeIndexedFields(SerialNum serialNum, const LidVector &lidsToRemove, - bool immediateCommit, OnWriteDoneType onWriteDone) + OnWriteDoneType onWriteDone) { if (!hasIndexedFields()) return; _writeService.index().execute( makeLambdaTask([=]() { - performIndexRemove(serialNum, lidsToRemove, immediateCommit, onWriteDone); + performIndexRemove(serialNum, lidsToRemove, onWriteDone); })); } diff --git a/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.h b/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.h index 3265bc0ae70..944d383e06d 100644 --- a/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.h +++ b/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.h @@ -34,38 +34,21 @@ private: bool hasIndexedFields() const { return _hasIndexedFields; 
} - void performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const Document &doc, - bool immediateCommit, OnOperationDoneType onWriteDone); - - void performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const DocumentSP &doc, - bool immediateCommit, OnOperationDoneType onWriteDone); - void performIndexPut(SerialNum serialNum, search::DocumentIdT lid, FutureDoc doc, - bool immediateCommit, OnOperationDoneType onWriteDone); - - void performIndexRemove(SerialNum serialNum, search::DocumentIdT lid, - bool immediateCommit, OnRemoveDoneType onWriteDone); - - void performIndexRemove(SerialNum serialNum, const LidVector &lidsToRemove, - bool immediateCommit, OnWriteDoneType onWriteDone); - + void performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const Document &doc, OnOperationDoneType onWriteDone); + void performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const DocumentSP &doc, OnOperationDoneType onWriteDone); + void performIndexPut(SerialNum serialNum, search::DocumentIdT lid, FutureDoc doc, OnOperationDoneType onWriteDone); + void performIndexRemove(SerialNum serialNum, search::DocumentIdT lid, OnRemoveDoneType onWriteDone); + void performIndexRemove(SerialNum serialNum, const LidVector &lidsToRemove, OnWriteDoneType onWriteDone); void performIndexHeartBeat(SerialNum serialNum); - void internalDeleteBucket(const DeleteBucketOperation &delOp) override; void performSync(); void heartBeatIndexedFields(SerialNum serialNum) override; - void putIndexedFields(SerialNum serialNum, search::DocumentIdT lid, const DocumentSP &newDoc, - bool immediateCommit, OnOperationDoneType onWriteDone) override; - - void updateIndexedFields(SerialNum serialNum, search::DocumentIdT lid, FutureDoc newDoc, - bool immediateCommit, OnOperationDoneType onWriteDone) override; - - void removeIndexedFields(SerialNum serialNum, search::DocumentIdT lid, - bool immediateCommit, OnRemoveDoneType onWriteDone) override; - - void removeIndexedFields(SerialNum 
serialNum, const LidVector &lidsToRemove, - bool immediateCommit, OnWriteDoneType onWriteDone) override; + void putIndexedFields(SerialNum serialNum, search::DocumentIdT lid, const DocumentSP &newDoc, OnOperationDoneType onWriteDone) override; + void updateIndexedFields(SerialNum serialNum, search::DocumentIdT lid, FutureDoc newDoc, OnOperationDoneType onWriteDone) override; + void removeIndexedFields(SerialNum serialNum, search::DocumentIdT lid, OnRemoveDoneType onWriteDone) override; + void removeIndexedFields(SerialNum serialNum, const LidVector &lidsToRemove, OnWriteDoneType onWriteDone) override; void performIndexForceCommit(SerialNum serialNum, OnForceCommitDoneType onCommitDone); void internalForceCommit(SerialNum serialNum, OnForceCommitDoneType onCommitDone) override; diff --git a/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.cpp b/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.cpp index 3db55cf6755..186c321d920 100644 --- a/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.cpp @@ -288,10 +288,10 @@ StoreOnlyFeedView::considerEarlyAck(FeedToken & token) } void -StoreOnlyFeedView::putAttributes(SerialNum, Lid, const Document &, bool, OnPutDoneType) {} +StoreOnlyFeedView::putAttributes(SerialNum, Lid, const Document &, OnPutDoneType) {} void -StoreOnlyFeedView::putIndexedFields(SerialNum, Lid, const Document::SP &, bool, OnOperationDoneType) {} +StoreOnlyFeedView::putIndexedFields(SerialNum, Lid, const Document::SP &, OnOperationDoneType) {} void StoreOnlyFeedView::preparePut(PutOperation &putOp) @@ -334,15 +334,14 @@ StoreOnlyFeedView::internalPut(FeedToken token, const PutOperation &putOp) bool docAlreadyExists = putOp.getValidPrevDbdId(_params._subDbId); if (putOp.getValidDbdId(_params._subDbId)) { - bool immediateCommit = needImmediateCommit(); const document::GlobalId &gid = docId.getGlobalId(); std::shared_ptr<PutDoneContext> 
onWriteDone = createPutDoneContext(std::move(token), std::move(uncommitted), _gidToLidChangeHandler, doc, gid, putOp.getLid(), serialNum, putOp.changedDbdId() && useDocumentMetaStore(serialNum)); putSummary(serialNum, putOp.getLid(), doc, onWriteDone); - putAttributes(serialNum, putOp.getLid(), *doc, immediateCommit, onWriteDone); - putIndexedFields(serialNum, putOp.getLid(), doc, immediateCommit, onWriteDone); + putAttributes(serialNum, putOp.getLid(), *doc, onWriteDone); + putIndexedFields(serialNum, putOp.getLid(), doc, onWriteDone); } if (docAlreadyExists && putOp.changedDbdId()) { assert(!putOp.getValidDbdId(_params._subDbId)); @@ -369,7 +368,7 @@ void StoreOnlyFeedView::heartBeatAttributes(SerialNum ) {} void -StoreOnlyFeedView::updateAttributes(SerialNum, Lid, const DocumentUpdate & upd, bool, +StoreOnlyFeedView::updateAttributes(SerialNum, Lid, const DocumentUpdate & upd, OnOperationDoneType, IFieldUpdateCallback & onUpdate) { for (const auto & fieldUpdate : upd.getUpdates()) { @@ -378,12 +377,12 @@ StoreOnlyFeedView::updateAttributes(SerialNum, Lid, const DocumentUpdate & upd, } void -StoreOnlyFeedView::updateAttributes(SerialNum, Lid, FutureDoc, bool, OnOperationDoneType) +StoreOnlyFeedView::updateAttributes(SerialNum, Lid, FutureDoc, OnOperationDoneType) { } void -StoreOnlyFeedView::updateIndexedFields(SerialNum, Lid, FutureDoc, bool, OnOperationDoneType) +StoreOnlyFeedView::updateIndexedFields(SerialNum, Lid, FutureDoc, OnOperationDoneType) { } @@ -495,10 +494,9 @@ StoreOnlyFeedView::internalUpdate(FeedToken token, const UpdateOperation &updOp) auto uncommitted = get_pending_lid_token(updOp); considerEarlyAck(token); - bool immediateCommit = needImmediateCommit(); auto onWriteDone = createUpdateDoneContext(std::move(token), std::move(uncommitted), updOp.getUpdate()); UpdateScope updateScope(*_schema, upd); - updateAttributes(serialNum, lid, upd, immediateCommit, onWriteDone, updateScope); + updateAttributes(serialNum, lid, upd, onWriteDone, 
updateScope); if (updateScope.hasIndexOrNonAttributeFields()) { PromisedDoc promisedDoc; @@ -506,7 +504,7 @@ StoreOnlyFeedView::internalUpdate(FeedToken token, const UpdateOperation &updOp) onWriteDone->setDocument(futureDoc); _pendingLidsForDocStore.waitComplete(lid); if (updateScope._indexedFields) { - updateIndexedFields(serialNum, lid, futureDoc, immediateCommit, onWriteDone); + updateIndexedFields(serialNum, lid, futureDoc, onWriteDone); } PromisedStream promisedStream; FutureStream futureStream = promisedStream.get_future(); @@ -522,7 +520,7 @@ StoreOnlyFeedView::internalUpdate(FeedToken token, const UpdateOperation &updOp) makeUpdatedDocument(serialNum, lid, *upd, onWriteDone, std::move(promisedDoc), std::move(promisedStream)); })); - updateAttributes(serialNum, lid, std::move(futureDoc), immediateCommit, onWriteDone); + updateAttributes(serialNum, lid, std::move(futureDoc), onWriteDone); } } @@ -576,10 +574,10 @@ StoreOnlyFeedView::lookupDocId(const DocumentId &docId, Lid &lid) const } void -StoreOnlyFeedView::removeAttributes(SerialNum, Lid, bool, OnRemoveDoneType) {} +StoreOnlyFeedView::removeAttributes(SerialNum, Lid, OnRemoveDoneType) {} void -StoreOnlyFeedView::removeIndexedFields(SerialNum, Lid, bool, OnRemoveDoneType) {} +StoreOnlyFeedView::removeIndexedFields(SerialNum, Lid, OnRemoveDoneType) {} void StoreOnlyFeedView::prepareRemove(RemoveOperation &rmOp) @@ -666,9 +664,8 @@ StoreOnlyFeedView::internalRemove(FeedToken token, IPendingLidTracker::Token unc std::move(pendingNotifyRemoveDone), (explicitReuseLid ? 
lid : 0u), std::move(moveDoneCtx)); removeSummary(serialNum, lid, onWriteDone); - bool immediateCommit = needImmediateCommit(); - removeAttributes(serialNum, lid, immediateCommit, onWriteDone); - removeIndexedFields(serialNum, lid, immediateCommit, onWriteDone); + removeAttributes(serialNum, lid, onWriteDone); + removeIndexedFields(serialNum, lid, onWriteDone); } PendingNotifyRemoveDone @@ -699,14 +696,13 @@ StoreOnlyFeedView::adjustMetaStore(const DocumentOperation &op, const GlobalId & } void -StoreOnlyFeedView::removeAttributes(SerialNum, const LidVector &, bool , OnWriteDoneType ) {} +StoreOnlyFeedView::removeAttributes(SerialNum, const LidVector &, OnWriteDoneType ) {} void -StoreOnlyFeedView::removeIndexedFields(SerialNum , const LidVector &, bool , OnWriteDoneType ) {} +StoreOnlyFeedView::removeIndexedFields(SerialNum , const LidVector &, OnWriteDoneType ) {} size_t -StoreOnlyFeedView::removeDocuments(const RemoveDocumentsOperation &op, bool remove_index_and_attributes, - bool immediateCommit) +StoreOnlyFeedView::removeDocuments(const RemoveDocumentsOperation &op, bool remove_index_and_attributes) { const SerialNum serialNum = op.getSerialNum(); const LidVectorContext::SP &ctx = op.getLidsToRemove(_params._subDbId); @@ -744,8 +740,8 @@ StoreOnlyFeedView::removeDocuments(const RemoveDocumentsOperation &op, bool remo onWriteDone = std::make_shared<RemoveBatchDoneContext>(_writeService.master(), std::move(removeBatchDoneTask), _gidToLidChangeHandler, std::move(gidsToRemove), serialNum); if (remove_index_and_attributes) { - removeIndexedFields(serialNum, lidsToRemove, immediateCommit, onWriteDone); - removeAttributes(serialNum, lidsToRemove, immediateCommit, onWriteDone); + removeIndexedFields(serialNum, lidsToRemove, onWriteDone); + removeAttributes(serialNum, lidsToRemove, onWriteDone); } if (useDocumentStore(serialNum + 1)) { for (const auto &lid : lidsToRemove) { @@ -779,8 +775,7 @@ StoreOnlyFeedView::handleDeleteBucket(const DeleteBucketOperation &delOp) 
void StoreOnlyFeedView::internalDeleteBucket(const DeleteBucketOperation &delOp) { - bool immediateCommit = needImmediateCommit(); - size_t rm_count = removeDocuments(delOp, true, immediateCommit); + size_t rm_count = removeDocuments(delOp, true); LOG(debug, "internalDeleteBucket(): docType(%s), bucket(%s), lidsToRemove(%zu)", _params._docTypeName.toString().c_str(), delOp.getBucketId().toString().c_str(), rm_count); } @@ -818,15 +813,14 @@ StoreOnlyFeedView::handleMove(const MoveOperation &moveOp, IDestructorCallback:: PendingNotifyRemoveDone pendingNotifyRemoveDone = adjustMetaStore(moveOp, docId.getGlobalId(), docId); bool docAlreadyExists = moveOp.getValidPrevDbdId(_params._subDbId); if (moveOp.getValidDbdId(_params._subDbId)) { - bool immediateCommit = needImmediateCommit(); const document::GlobalId &gid = docId.getGlobalId(); std::shared_ptr<PutDoneContext> onWriteDone = createPutDoneContext(FeedToken(), _pendingLidsForCommit->produce(moveOp.getLid()), _gidToLidChangeHandler, doc, gid, moveOp.getLid(), serialNum, moveOp.changedDbdId() && useDocumentMetaStore(serialNum), doneCtx); putSummary(serialNum, moveOp.getLid(), doc, onWriteDone); - putAttributes(serialNum, moveOp.getLid(), *doc, immediateCommit, onWriteDone); - putIndexedFields(serialNum, moveOp.getLid(), doc, immediateCommit, onWriteDone); + putAttributes(serialNum, moveOp.getLid(), *doc, onWriteDone); + putIndexedFields(serialNum, moveOp.getLid(), doc, onWriteDone); } if (docAlreadyExists && moveOp.changedDbdId()) { internalRemove(FeedToken(), _pendingLidsForCommit->produce(moveOp.getPrevLid()), serialNum, std::move(pendingNotifyRemoveDone), moveOp.getPrevLid(), doneCtx); @@ -853,7 +847,7 @@ handlePruneRemovedDocuments(const PruneRemovedDocumentsOperation &pruneOp) { assert(_params._subDbType == SubDbType::REMOVED); assert(pruneOp.getSubDbId() == _params._subDbId); - uint32_t rm_count = removeDocuments(pruneOp, false, false); + uint32_t rm_count = removeDocuments(pruneOp, false); LOG(debug, 
"MinimalFeedView::handlePruneRemovedDocuments called, doctype(%s) %u lids pruned, limit %u", _params._docTypeName.toString().c_str(), rm_count, diff --git a/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.h b/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.h index 167b246ec0b..da1459d521c 100644 --- a/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.h +++ b/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.h @@ -181,8 +181,7 @@ private: // Removes documents from meta store and document store. // returns the number of documents removed. - size_t removeDocuments(const RemoveDocumentsOperation &op, bool remove_index_and_attribute_fields, - bool immediateCommit); + size_t removeDocuments(const RemoveDocumentsOperation &op, bool remove_index_and_attribute_fields); void internalRemove(FeedToken token, IPendingLidTracker::Token uncommitted, SerialNum serialNum, PendingNotifyRemoveDone &&pendingNotifyRemoveDone, @@ -202,30 +201,20 @@ protected: virtual void heartBeatAttributes(SerialNum serialNum); private: - virtual void putAttributes(SerialNum serialNum, Lid lid, const Document &doc, - bool immediateCommit, OnPutDoneType onWriteDone); - - virtual void putIndexedFields(SerialNum serialNum, Lid lid, const DocumentSP &newDoc, - bool immediateCommit, OnOperationDoneType onWriteDone); + virtual void putAttributes(SerialNum serialNum, Lid lid, const Document &doc, OnPutDoneType onWriteDone); + virtual void putIndexedFields(SerialNum serialNum, Lid lid, const DocumentSP &newDoc, OnOperationDoneType onWriteDone); virtual void updateAttributes(SerialNum serialNum, Lid lid, const DocumentUpdate &upd, - bool immediateCommit, OnOperationDoneType onWriteDone, IFieldUpdateCallback & onUpdate); - - virtual void updateAttributes(SerialNum serialNum, Lid lid, FutureDoc doc, - bool immediateCommit, OnOperationDoneType onWriteDone); + OnOperationDoneType onWriteDone, IFieldUpdateCallback & onUpdate); - virtual void 
updateIndexedFields(SerialNum serialNum, Lid lid, FutureDoc doc, - bool immediateCommit, OnOperationDoneType onWriteDone); - - virtual void removeAttributes(SerialNum serialNum, Lid lid, bool immediateCommit, OnRemoveDoneType onWriteDone); - virtual void removeIndexedFields(SerialNum serialNum, Lid lid, bool immediateCommit, OnRemoveDoneType onWriteDone); + virtual void updateAttributes(SerialNum serialNum, Lid lid, FutureDoc doc, OnOperationDoneType onWriteDone); + virtual void updateIndexedFields(SerialNum serialNum, Lid lid, FutureDoc doc, OnOperationDoneType onWriteDone); + virtual void removeAttributes(SerialNum serialNum, Lid lid, OnRemoveDoneType onWriteDone); + virtual void removeIndexedFields(SerialNum serialNum, Lid lid, OnRemoveDoneType onWriteDone); protected: - virtual void removeAttributes(SerialNum serialNum, const LidVector &lidsToRemove, - bool immediateCommit, OnWriteDoneType onWriteDone); - - virtual void removeIndexedFields(SerialNum serialNum, const LidVector &lidsToRemove, - bool immediateCommit, OnWriteDoneType onWriteDone); + virtual void removeAttributes(SerialNum serialNum, const LidVector &lidsToRemove, OnWriteDoneType onWriteDone); + virtual void removeIndexedFields(SerialNum serialNum, const LidVector &lidsToRemove, OnWriteDoneType onWriteDone); virtual void internalForceCommit(SerialNum serialNum, OnForceCommitDoneType onCommitDone); public: StoreOnlyFeedView(const Context &ctx, const PersistentParams ¶ms); diff --git a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp index 9142a03ab85..9de47b4a8a9 100644 --- a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp +++ b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp @@ -60,9 +60,17 @@ spi::LoadType defaultLoadType(0, "default"); struct TestFileStorComponents; +document::Bucket +make_bucket_for_doc(const document::DocumentId& docid) +{ + document::BucketIdFactory factory; 
+ document::BucketId bucket_id(16, factory.getBucketId(docid).getRawId()); + return makeDocumentBucket(bucket_id); +} + } -struct FileStorManagerTest : Test{ +struct FileStorTestBase : Test { enum {LONG_WAITTIME=60}; unique_ptr<TestServiceLayerApp> _node; std::unique_ptr<vdstestlib::DirConfig> config; @@ -71,13 +79,13 @@ struct FileStorManagerTest : Test{ const uint32_t _waitTime; const document::DocumentType* _testdoctype1; - FileStorManagerTest() : _node(), _waitTime(LONG_WAITTIME) {} + FileStorTestBase() : _node(), _waitTime(LONG_WAITTIME) {} + ~FileStorTestBase(); void SetUp() override; void TearDown() override; - void createBucket(document::BucketId bid, uint16_t disk) - { + void createBucket(document::BucketId bid, uint16_t disk) { spi::Context context(defaultLoadType, spi::Priority(0), spi::Trace::TraceLevel(0)); assert(disk == 0u); _node->getPersistenceProvider().createBucket(makeSpiBucket(bid), context); @@ -88,11 +96,29 @@ struct FileStorManagerTest : Test{ entry.write(); } - document::Document::UP createDocument(const std::string& content, const std::string& id) - { + document::Document::UP createDocument(const std::string& content, const std::string& id) { return _node->getTestDocMan().createDocument(content, id); } + std::shared_ptr<api::PutCommand> make_put_command(StorageMessage::Priority pri = 20, + const std::string& docid = "id:foo:testdoctype1::bar", + Timestamp timestamp = 100) { + Document::SP doc(createDocument("my content", docid).release()); + auto bucket = make_bucket_for_doc(doc->getId()); + auto cmd = std::make_shared<api::PutCommand>(bucket, std::move(doc), timestamp); + cmd->setPriority(pri); + return cmd; + } + + std::shared_ptr<api::GetCommand> make_get_command(StorageMessage::Priority pri, + const std::string& docid = "id:foo:testdoctype1::bar") { + document::DocumentId did(docid); + auto bucket = make_bucket_for_doc(did); + auto cmd = std::make_shared<api::GetCommand>(bucket, did, document::AllFields::NAME); + cmd->setPriority(pri); 
+ return cmd; + } + bool ownsBucket(uint16_t distributorIndex, const document::BucketId& bucket) const { @@ -163,10 +189,12 @@ struct FileStorManagerTest : Test{ const Metric& metric); auto& thread_metrics_of(FileStorManager& manager) { - return manager._metrics->disk->threads[0]; + return manager.get_metrics().disk->threads[0]; } }; +FileStorTestBase::~FileStorTestBase() = default; + std::string findFile(const std::string& path, const std::string& file) { FastOS_DirectoryScan dirScan(path.c_str()); while (dirScan.ReadNext()) { @@ -207,7 +235,7 @@ struct TestFileStorComponents { DummyStorageLink top; FileStorManager* manager; - explicit TestFileStorComponents(FileStorManagerTest& test, + explicit TestFileStorComponents(FileStorTestBase& test, bool use_small_config = false) : manager(new FileStorManager((use_small_config ? test.smallConfig : test.config)->getConfigId(), test._node->getPersistenceProvider(), @@ -227,7 +255,7 @@ struct FileStorHandlerComponents { FileStorMetrics metrics; std::unique_ptr<FileStorHandler> filestorHandler; - FileStorHandlerComponents(FileStorManagerTest& test, uint32_t threadsPerDisk = 1) + FileStorHandlerComponents(FileStorTestBase& test, uint32_t threadsPerDisk = 1) : top(), dummyManager(new DummyStorageLink), messageSender(*dummyManager), @@ -253,7 +281,7 @@ struct PersistenceHandlerComponents : public FileStorHandlerComponents { BucketOwnershipNotifier bucketOwnershipNotifier; std::unique_ptr<PersistenceHandler> persistenceHandler; - PersistenceHandlerComponents(FileStorManagerTest& test) + PersistenceHandlerComponents(FileStorTestBase& test) : FileStorHandlerComponents(test), component(test._node->getComponentRegister(), "test"), bucketOwnershipNotifier(component, messageSender), @@ -277,17 +305,21 @@ PersistenceHandlerComponents::~PersistenceHandlerComponents() = default; } void -FileStorManagerTest::SetUp() +FileStorTestBase::SetUp() { setupDisks(); } void -FileStorManagerTest::TearDown() +FileStorTestBase::TearDown() { 
_node.reset(0); } +struct FileStorManagerTest : public FileStorTestBase { + +}; + TEST_F(FileStorManagerTest, header_only_put) { TestFileStorComponents c(*this); auto& top = c.top; @@ -947,10 +979,10 @@ TEST_F(FileStorManagerTest, split_single_group) { } void -FileStorManagerTest::putDoc(DummyStorageLink& top, - FileStorHandler& filestorHandler, - const document::BucketId& target, - uint32_t docNum) +FileStorTestBase::putDoc(DummyStorageLink& top, + FileStorHandler& filestorHandler, + const document::BucketId& target, + uint32_t docNum) { api::StorageMessageAddress address("storage", lib::NodeType::STORAGE, 3); spi::Context context(defaultLoadType, spi::Priority(0), @@ -1838,7 +1870,7 @@ TEST_F(FileStorManagerTest, create_bucket_sets_active_flag_in_database_and_reply } template <typename Metric> -void FileStorManagerTest::assert_request_size_set(TestFileStorComponents& c, std::shared_ptr<api::StorageMessage> cmd, const Metric& metric) { +void FileStorTestBase::assert_request_size_set(TestFileStorComponents& c, std::shared_ptr<api::StorageMessage> cmd, const Metric& metric) { api::StorageMessageAddress address("storage", lib::NodeType::STORAGE, 3); cmd->setApproxByteSize(54321); cmd->setAddress(address); @@ -1965,4 +1997,97 @@ TEST_F(FileStorManagerTest, bucket_db_is_populated_from_provider_when_initialize EXPECT_EQ(reported_state->getState(), lib::State::UP); } +struct FileStorHandlerTest : public FileStorTestBase { + std::unique_ptr<FileStorHandlerComponents> c; + FileStorHandler* handler; + FileStorHandlerTest() + : FileStorTestBase(), + c(), + handler() + {} + void SetUp() override { + FileStorTestBase::SetUp(); + c = std::make_unique<FileStorHandlerComponents>(*this); + handler = c->filestorHandler.get(); + } + FileStorHandler::LockedMessage get_next_message() { + return handler->getNextMessage(0); + } +}; + +void +expect_async_message(StorageMessage::Priority exp_pri, + const FileStorHandler::ScheduleAsyncResult& result) +{ + 
EXPECT_TRUE(result.was_scheduled()); + ASSERT_TRUE(result.has_async_message()); + EXPECT_EQ(exp_pri, result.async_message().second->getPriority()); +} + +void +expect_empty_async_message(const FileStorHandler::ScheduleAsyncResult& result) +{ + EXPECT_TRUE(result.was_scheduled()); + EXPECT_FALSE(result.has_async_message()); +} + +TEST_F(FileStorHandlerTest, message_not_scheduled_if_handler_is_closed) +{ + handler->setDiskState(FileStorHandler::DiskState::CLOSED); + auto result = handler->schedule_and_get_next_async_message(make_put_command()); + EXPECT_FALSE(result.was_scheduled()); +} + +TEST_F(FileStorHandlerTest, no_async_message_returned_if_handler_is_paused) +{ + auto guard = handler->pause(); + auto result = handler->schedule_and_get_next_async_message(make_put_command()); + expect_empty_async_message(result); +} + +TEST_F(FileStorHandlerTest, async_message_with_lowest_pri_returned_on_schedule) +{ + handler->schedule(make_put_command(20)); + handler->schedule(make_put_command(40)); + { + auto result = handler->schedule_and_get_next_async_message(make_put_command(30)); + expect_async_message(20, result); + } + EXPECT_EQ(30, get_next_message().second->getPriority()); + EXPECT_EQ(40, get_next_message().second->getPriority()); +} + +TEST_F(FileStorHandlerTest, no_async_message_returned_if_lowest_pri_message_is_not_async) +{ + // GET is not an async message. 
+ handler->schedule(make_get_command(20)); + + auto result = handler->schedule_and_get_next_async_message(make_put_command(30)); + expect_empty_async_message(result); + + EXPECT_EQ(20, get_next_message().second->getPriority()); + EXPECT_EQ(30, get_next_message().second->getPriority()); +} + +TEST_F(FileStorHandlerTest, inhibited_operations_are_skipped) +{ + std::string docid_a = "id:foo:testdoctype1::a"; + std::string docid_b = "id:foo:testdoctype1::b"; + handler->schedule(make_put_command(20, docid_a)); + { + auto locked_msg = get_next_message(); + { + // Bucket for docid_a is locked and put command for same bucket is inhibited. + auto result = handler->schedule_and_get_next_async_message(make_put_command(30, docid_a)); + expect_empty_async_message(result); + } + { + // Put command for another bucket is ok. + auto result = handler->schedule_and_get_next_async_message(make_put_command(40, docid_b)); + expect_async_message(40, result); + } + } + EXPECT_EQ(30, get_next_message().second->getPriority()); +} + } // storage diff --git a/storage/src/vespa/storage/bucketdb/btree_lockable_map.h b/storage/src/vespa/storage/bucketdb/btree_lockable_map.h index ea3a7838d43..6e42a721732 100644 --- a/storage/src/vespa/storage/bucketdb/btree_lockable_map.h +++ b/storage/src/vespa/storage/bucketdb/btree_lockable_map.h @@ -37,7 +37,7 @@ public: using BucketId = document::BucketId; BTreeLockableMap(); - ~BTreeLockableMap(); + ~BTreeLockableMap() override; bool operator==(const BTreeLockableMap& other) const; bool operator!=(const BTreeLockableMap& other) const { diff --git a/storage/src/vespa/storage/persistence/asynchandler.cpp b/storage/src/vespa/storage/persistence/asynchandler.cpp index 1d1f5caf673..5344553dd45 100644 --- a/storage/src/vespa/storage/persistence/asynchandler.cpp +++ b/storage/src/vespa/storage/persistence/asynchandler.cpp @@ -182,6 +182,19 @@ AsyncHandler::handleRemove(api::RemoveCommand& cmd, MessageTracker::UP trackerUP } bool 
+AsyncHandler::is_async_message(api::MessageType::Id type_id) noexcept +{ + switch (type_id) { + case api::MessageType::PUT_ID: + case api::MessageType::UPDATE_ID: + case api::MessageType::REMOVE_ID: + return true; + default: + return false; + } +} + +bool AsyncHandler::tasConditionExists(const api::TestAndSetCommand & cmd) { return cmd.getCondition().isPresent(); } diff --git a/storage/src/vespa/storage/persistence/asynchandler.h b/storage/src/vespa/storage/persistence/asynchandler.h index c25f2ea0be6..92bf72e7c51 100644 --- a/storage/src/vespa/storage/persistence/asynchandler.h +++ b/storage/src/vespa/storage/persistence/asynchandler.h @@ -25,6 +25,7 @@ public: MessageTrackerUP handlePut(api::PutCommand& cmd, MessageTrackerUP tracker) const; MessageTrackerUP handleRemove(api::RemoveCommand& cmd, MessageTrackerUP tracker) const; MessageTrackerUP handleUpdate(api::UpdateCommand& cmd, MessageTrackerUP tracker) const; + static bool is_async_message(api::MessageType::Id type_id) noexcept; private: static bool tasConditionExists(const api::TestAndSetCommand & cmd); bool tasConditionMatches(const api::TestAndSetCommand & cmd, MessageTracker & tracker, diff --git a/storage/src/vespa/storage/persistence/filestorage/filestorhandler.h b/storage/src/vespa/storage/persistence/filestorage/filestorhandler.h index 44e768c9db7..aafc87aa84f 100644 --- a/storage/src/vespa/storage/persistence/filestorage/filestorhandler.h +++ b/storage/src/vespa/storage/persistence/filestorage/filestorhandler.h @@ -58,6 +58,30 @@ public: }; using LockedMessage = std::pair<BucketLockInterface::SP, api::StorageMessage::SP>; + class ScheduleAsyncResult { + private: + bool _was_scheduled; + LockedMessage _async_message; + + public: + ScheduleAsyncResult() : _was_scheduled(false), _async_message() {} + explicit ScheduleAsyncResult(LockedMessage&& async_message_in) + : _was_scheduled(true), + _async_message(std::move(async_message_in)) + {} + bool was_scheduled() const { + return _was_scheduled; + } + 
bool has_async_message() const { + return _async_message.first.get() != nullptr; + } + const LockedMessage& async_message() const { + return _async_message; + } + LockedMessage&& release_async_message() { + return std::move(_async_message); + } + }; enum DiskState { AVAILABLE, @@ -104,6 +128,11 @@ public: virtual bool schedule(const std::shared_ptr<api::StorageMessage>&) = 0; /** + * Schedule the given message to be processed and return the next async message to process (if any). + */ + virtual ScheduleAsyncResult schedule_and_get_next_async_message(const std::shared_ptr<api::StorageMessage>& msg) = 0; + + /** * Used by file stor threads to get their next message to process. * * @param stripe The stripe to get messages for diff --git a/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.cpp b/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.cpp index 0c34a421c06..14074b65c5c 100644 --- a/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.cpp +++ b/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.cpp @@ -10,6 +10,7 @@ #include <vespa/storage/common/statusmessages.h> #include <vespa/storage/common/bucketoperationlogger.h> #include <vespa/storage/common/messagebucket.h> +#include <vespa/storage/persistence/asynchandler.h> #include <vespa/storage/persistence/messages.h> #include <vespa/storageapi/message/stat.h> #include <vespa/vespalib/stllike/hash_map.hpp> @@ -258,6 +259,16 @@ FileStorHandlerImpl::schedule(const std::shared_ptr<api::StorageMessage>& msg) return false; } +FileStorHandler::ScheduleAsyncResult +FileStorHandlerImpl::schedule_and_get_next_async_message(const std::shared_ptr<api::StorageMessage>& msg) +{ + if (getState() == FileStorHandler::AVAILABLE) { + document::Bucket bucket = getStorageMessageBucket(*msg); + return ScheduleAsyncResult(stripe(bucket).schedule_and_get_next_async_message(MessageEntry(msg, bucket))); + } + return {}; +} + bool 
FileStorHandlerImpl::messageMayBeAborted(const api::StorageMessage& msg) { @@ -911,6 +922,24 @@ FileStorHandlerImpl::Stripe::getNextMessage(vespalib::duration timeout) } FileStorHandler::LockedMessage +FileStorHandlerImpl::Stripe::get_next_async_message(monitor_guard& guard) +{ + if (_owner.isClosed() || _owner.isPaused()) { + return {}; + } + PriorityIdx& idx(bmi::get<1>(*_queue)); + PriorityIdx::iterator iter(idx.begin()), end(idx.end()); + + while ((iter != end) && operationIsInhibited(guard, iter->_bucket, *iter->_command)) { + ++iter; + } + if ((iter != end) && AsyncHandler::is_async_message(iter->_command->getType().getId())) { + return getMessage(guard, idx, iter); + } + return {}; +} + +FileStorHandler::LockedMessage FileStorHandlerImpl::Stripe::getMessage(monitor_guard & guard, PriorityIdx & idx, PriorityIdx::iterator iter) { api::StorageMessage & m(*iter->_command); @@ -989,6 +1018,14 @@ bool FileStorHandlerImpl::Stripe::schedule(MessageEntry messageEntry) return true; } +FileStorHandler::LockedMessage +FileStorHandlerImpl::Stripe::schedule_and_get_next_async_message(MessageEntry entry) +{ + std::unique_lock guard(*_lock); + _queue->emplace_back(std::move(entry)); + return get_next_async_message(guard); +} + void FileStorHandlerImpl::Stripe::flush() { diff --git a/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.h b/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.h index 6aac8b0474b..549de164229 100644 --- a/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.h +++ b/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.h @@ -101,6 +101,7 @@ public: ~Stripe(); void flush(); bool schedule(MessageEntry messageEntry); + FileStorHandler::LockedMessage schedule_and_get_next_async_message(MessageEntry entry); void waitUntilNoLocks() const; void abort(std::vector<std::shared_ptr<api::StorageReply>> & aborted, const AbortBucketOperationsCommand& cmd); void waitInactive(const 
AbortBucketOperationsCommand& cmd) const; @@ -137,6 +138,8 @@ public: void setMetrics(FileStorStripeMetrics * metrics) { _metrics = metrics; } private: bool hasActive(monitor_guard & monitor, const AbortBucketOperationsCommand& cmd) const; + FileStorHandler::LockedMessage get_next_async_message(monitor_guard& guard); + // Precondition: the bucket used by `iter`s operation is not locked in a way that conflicts // with its locking requirements. FileStorHandler::LockedMessage getMessage(monitor_guard & guard, PriorityIdx & idx, @@ -184,6 +187,7 @@ public: DiskState getDiskState() const override; void close() override; bool schedule(const std::shared_ptr<api::StorageMessage>&) override; + ScheduleAsyncResult schedule_and_get_next_async_message(const std::shared_ptr<api::StorageMessage>& msg) override; FileStorHandler::LockedMessage getNextMessage(uint32_t stripeId) override; diff --git a/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp b/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp index 188523af38d..2653391ecfa 100644 --- a/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp +++ b/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp @@ -48,6 +48,7 @@ FileStorManager(const config::ConfigUri & configUri, spi::PersistenceProvider& p _configFetcher(configUri.getContext()), _threadLockCheckInterval(60), _failDiskOnError(false), + _use_async_message_handling_on_schedule(false), _metrics(std::make_unique<FileStorMetrics>(_component.getLoadTypes()->getMetricLoadTypes())), _closed(false), _lock() @@ -151,6 +152,7 @@ FileStorManager::configure(std::unique_ptr<vespa::config::content::StorFilestorC _threadLockCheckInterval = config->diskOperationTimeout; _failDiskOnError = (config->failDiskAfterErrorCount > 0); + _use_async_message_handling_on_schedule = config->useAsyncMessageHandlingOnSchedule; if (!liveUpdate) { _config = std::move(config); @@ -258,10 +260,20 @@ 
FileStorManager::handlePersistenceMessage(const shared_ptr<api::StorageMessage>& api::ReturnCode errorCode(api::ReturnCode::OK); LOG(spam, "Received %s. Attempting to queue it.", msg->getType().getName().c_str()); - if (_filestorHandler->schedule(msg)) { - LOG(spam, "Received persistence message %s. Queued it to disk", - msg->getType().getName().c_str()); - return true; + if (_use_async_message_handling_on_schedule) { + auto result = _filestorHandler->schedule_and_get_next_async_message(msg); + if (result.was_scheduled()) { + if (result.has_async_message()) { + getThreadLocalHandler().processLockedMessage(result.release_async_message()); + } + return true; + } + } else { + if (_filestorHandler->schedule(msg)) { + LOG(spam, "Received persistence message %s. Queued it to disk", + msg->getType().getName().c_str()); + return true; + } } switch (_filestorHandler->getDiskState()) { case FileStorHandler::DISABLED: diff --git a/storage/src/vespa/storage/persistence/filestorage/filestormanager.h b/storage/src/vespa/storage/persistence/filestorage/filestormanager.h index ee66bc7d77c..2953462dd1e 100644 --- a/storage/src/vespa/storage/persistence/filestorage/filestormanager.h +++ b/storage/src/vespa/storage/persistence/filestorage/filestormanager.h @@ -65,14 +65,13 @@ class FileStorManager : public StorageLinkQueued, config::ConfigFetcher _configFetcher; uint32_t _threadLockCheckInterval; // In seconds bool _failDiskOnError; + bool _use_async_message_handling_on_schedule; std::shared_ptr<FileStorMetrics> _metrics; std::unique_ptr<FileStorHandler> _filestorHandler; std::unique_ptr<vespalib::ISequencedTaskExecutor> _sequencedExecutor; bool _closed; std::mutex _lock; - friend struct FileStorManagerTest; - public: FileStorManager(const config::ConfigUri &, spi::PersistenceProvider&, ServiceLayerComponentRegister&, DoneInitializeHandler&); @@ -105,6 +104,8 @@ public: // yet at that point in time. 
void initialize_bucket_databases_from_provider(); + const FileStorMetrics& get_metrics() const { return *_metrics; } + private: void configure(std::unique_ptr<vespa::config::content::StorFilestorConfig> config) override; PersistenceHandler & createRegisteredHandler(const ServiceLayerComponent & component); diff --git a/storage/src/vespa/storage/persistence/mergehandler.cpp b/storage/src/vespa/storage/persistence/mergehandler.cpp index 4fe7333fb5f..c7c681a838b 100644 --- a/storage/src/vespa/storage/persistence/mergehandler.cpp +++ b/storage/src/vespa/storage/persistence/mergehandler.cpp @@ -403,9 +403,9 @@ MergeHandler::fetchLocalData( || (entries.empty() && alreadyFilled == 0)) { remainingSize -= entry->getSize(); + entries.push_back(std::move(entry)); LOG(spam, "Added %s, remainingSize is %u", entries.back()->toString().c_str(), remainingSize); - entries.push_back(std::move(entry)); } else { LOG(spam, "Adding %s would exceed chunk size limit of %u; " "not filling up any more diffs for current round", diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/DefaultZmsClient.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/DefaultZmsClient.java index eaf83238145..33cb6d7d5d4 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/DefaultZmsClient.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/DefaultZmsClient.java @@ -9,8 +9,10 @@ import com.yahoo.vespa.athenz.api.OktaAccessToken; import com.yahoo.vespa.athenz.api.OktaIdentityToken; import com.yahoo.vespa.athenz.client.common.ClientBase; import com.yahoo.vespa.athenz.client.zms.bindings.AccessResponseEntity; +import com.yahoo.vespa.athenz.client.zms.bindings.AssertionEntity; import com.yahoo.vespa.athenz.client.zms.bindings.DomainListResponseEntity; import com.yahoo.vespa.athenz.client.zms.bindings.MembershipResponseEntity; +import com.yahoo.vespa.athenz.client.zms.bindings.PolicyEntity; import 
com.yahoo.vespa.athenz.client.zms.bindings.ProviderResourceGroupRolesRequestEntity; import com.yahoo.vespa.athenz.client.zms.bindings.TenancyRequestEntity; import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider; @@ -23,6 +25,8 @@ import javax.net.ssl.SSLContext; import java.net.URI; import java.util.Collections; import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; import java.util.Set; import java.util.function.Supplier; @@ -149,6 +153,47 @@ public class DefaultZmsClient extends ClientBase implements ZmsClient { }); } + @Override + public void addPolicyRule(AthenzDomain athenzDomain, String athenzPolicy, String action, AthenzResourceName resourceName, AthenzRole athenzRole) { + URI uri = zmsUrl.resolve(String.format("domain/%s/policy/%s/assertion", + athenzDomain.getName(), athenzPolicy)); + HttpUriRequest request = RequestBuilder.put() + .setUri(uri) + .setEntity(toJsonStringEntity(new AssertionEntity(athenzRole.toResourceNameString(), resourceName.toResourceNameString(), action))) + .build(); + execute(request, response -> readEntity(response, Void.class)); + } + + @Override + public boolean deletePolicyRule(AthenzDomain athenzDomain, String athenzPolicy, String action, AthenzResourceName resourceName, AthenzRole athenzRole) { + URI uri = zmsUrl.resolve(String.format("domain/%s/policy/%s", + athenzDomain.getName(), athenzPolicy)); + HttpUriRequest request = RequestBuilder.get() + .setUri(uri) + .build(); + PolicyEntity policyEntity = execute(request, response -> readEntity(response, PolicyEntity.class)); + + OptionalInt assertionId = policyEntity.getAssertions().stream() + .filter(assertionEntity -> assertionEntity.getAction().equals(action) && + assertionEntity.getResource().equals(resourceName.toResourceNameString()) && + assertionEntity.getRole().equals(athenzRole.toResourceNameString())) + .mapToInt(AssertionEntity::getId).findFirst(); + + if (assertionId.isEmpty()) { + return false; + } + + uri = 
zmsUrl.resolve(String.format("domain/%s/policy/%s/assertion/%d", + athenzDomain.getName(), athenzPolicy, assertionId.getAsInt())); + + request = RequestBuilder.delete() + .setUri(uri) + .build(); + + execute(request, response -> readEntity(response, Void.class)); + return true; + } + private static Header createCookieHeaderWithOktaTokens(OktaIdentityToken identityToken, OktaAccessToken accessToken) { return new BasicHeader("Cookie", String.format("okta_at=%s; okta_it=%s", accessToken.token(), identityToken.token())); } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/ZmsClient.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/ZmsClient.java index 12762534bd4..c7f865a58bb 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/ZmsClient.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/ZmsClient.java @@ -38,5 +38,9 @@ public interface ZmsClient extends AutoCloseable { boolean hasAccess(AthenzResourceName resource, String action, AthenzIdentity identity); + void addPolicyRule(AthenzDomain athenzDomain, String athenzPolicy, String action, AthenzResourceName resourceName, AthenzRole athenzRole); + + boolean deletePolicyRule(AthenzDomain athenzDomain, String athenzPolicy, String action, AthenzResourceName resourceName, AthenzRole athenzRole); + void close(); } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/bindings/AssertionEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/bindings/AssertionEntity.java new file mode 100644 index 00000000000..824aa3b4606 --- /dev/null +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/bindings/AssertionEntity.java @@ -0,0 +1,52 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.vespa.athenz.client.zms.bindings; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * @author olaa + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +public class AssertionEntity { + + private final String role; + private final String resource; + private final String action; + private final Integer id; + + + public AssertionEntity(String role, String resource, String action) { + this(role, resource, action, null); + } + + public AssertionEntity(@JsonProperty("role") String role, + @JsonProperty("resource") String resource, + @JsonProperty("action") String action, + @JsonProperty("id") Integer id) { + this.role = role; + this.resource = resource; + this.action = action; + this.id = id; + } + + public String getRole() { + return role; + } + + public String getResource() { + return resource; + } + + public String getAction() { + return action; + } + + @JsonIgnore + public int getId() { + return id; + } +} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/bindings/PolicyEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/bindings/PolicyEntity.java new file mode 100644 index 00000000000..ebc0997cb09 --- /dev/null +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/bindings/PolicyEntity.java @@ -0,0 +1,33 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.vespa.athenz.client.zms.bindings; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +/** + * @author olaa + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public class PolicyEntity { + + @JsonInclude(JsonInclude.Include.NON_EMPTY) + private final List<AssertionEntity> assertions; + private final String name; + + public PolicyEntity(@JsonProperty("name") String name, + @JsonProperty("assertions") List<AssertionEntity> assertions) { + this.name = name; + this.assertions = assertions; + } + + public String getName() { + return name; + } + + public List<AssertionEntity> getAssertions() { + return assertions; + } +} diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java b/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java index 13ca774dc33..4fdac7b584a 100644 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java @@ -244,13 +244,19 @@ public class DocumentV1ApiHandler extends AbstractRequestHandler { @Override public void destroy() { executor.shutdown(); + Instant doom = clock.instant().plus(Duration.ofSeconds(20)); + while ( ! operations.isEmpty() && clock.instant().isBefore(doom)) + dispatchEnqueued(); + + if ( ! operations.isEmpty()) + log.log(WARNING, "Failed to empty request queue before shutdown timeout — " + operations.size() + " requests left"); + + asyncSession.destroy(); visits.values().forEach(VisitorSession::destroy); + try { - if ( ! executor.awaitTermination(10, TimeUnit.SECONDS)) { + if ( ! 
executor.awaitTermination(Duration.between(clock.instant(), doom).toMillis(), TimeUnit.MILLISECONDS)) executor.shutdownNow(); - if ( ! executor.awaitTermination(10, TimeUnit.SECONDS)) - log.log(WARNING, "Failed shutting down /document/v1 executor within 20 seconds"); - } } catch (InterruptedException e) { log.log(WARNING, "Interrupted waiting for /document/v1 executor to shut down"); @@ -729,13 +735,12 @@ public class DocumentV1ApiHandler extends AbstractRequestHandler { jsonResponse.commit(Response.Status.PRECONDITION_FAILED); break; case INSUFFICIENT_STORAGE: - log.log(WARNING, "Insufficient storage left in cluster: " + response.getTextMessage()); jsonResponse.commit(Response.Status.INSUFFICIENT_STORAGE); break; default: log.log(WARNING, "Unexpected document API operation outcome '" + response.outcome() + "'"); case ERROR: - log.log(WARNING, "Exception performing document operation: " + response.getTextMessage()); + log.log(FINE, () -> "Exception performing document operation: " + response.getTextMessage()); jsonResponse.commit(Response.Status.INTERNAL_SERVER_ERROR); } } |