summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/pom.xml9
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java27
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java228
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java19
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java8
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java141
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java97
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java154
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java9
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java2
-rw-r--r--config-model/src/main/protobuf/onnx.proto464
-rw-r--r--config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd2
-rwxr-xr-xconfig-model/src/test/integration/onnx-model/files/create_dynamic_model.py12
-rwxr-xr-xconfig-model/src/test/integration/onnx-model/files/create_model.py37
-rwxr-xr-xconfig-model/src/test/integration/onnx-model/files/create_unbound_model.py12
-rw-r--r--config-model/src/test/integration/onnx-model/files/dynamic_model.onnx13
-rw-r--r--config-model/src/test/integration/onnx-model/files/model.onnx34
-rw-r--r--config-model/src/test/integration/onnx-model/files/summary_model.onnx34
-rw-r--r--config-model/src/test/integration/onnx-model/files/unbound_model.onnx11
-rw-r--r--config-model/src/test/integration/onnx-model/searchdefinitions/test.sd52
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java89
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java36
-rw-r--r--configdefinitions/src/vespa/stor-filestor.def8
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java2
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/ZmsClientMock.java11
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/Node.java15
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/NodeRepository.java5
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/HostRepairClient.java23
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/MockRepairClient.java33
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/RepairTicketReport.java63
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/package-info.java5
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java26
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/api/ActivateResult.java9
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java10
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java4
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentExpirer.java5
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/HostRepairMaintainer.java81
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java39
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/QuotaUsageTest.java21
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/NodeRepositoryMock.java10
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java8
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/HostRepairMaintainerTest.java51
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json3
-rw-r--r--eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp11
-rw-r--r--fat-model-dependencies/pom.xml5
-rw-r--r--flags/src/main/java/com/yahoo/vespa/flags/Flags.java2
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java18
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsV2MetricsFetcher.java2
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java2
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainer.java30
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java10
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeCandidate.java19
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodePrioritizer.java50
-rw-r--r--searchcore/src/tests/proton/attribute/attribute_populator/attribute_populator_test.cpp4
-rw-r--r--searchcore/src/tests/proton/attribute/attribute_test.cpp62
-rw-r--r--searchcore/src/tests/proton/docsummary/docsummary.cpp29
-rw-r--r--searchcore/src/tests/proton/documentdb/document_subdbs/document_subdbs_test.cpp19
-rw-r--r--searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp83
-rw-r--r--searchcore/src/tests/proton/documentdb/storeonlyfeedview/storeonlyfeedview_test.cpp16
-rw-r--r--searchcore/src/vespa/searchcore/proton/attribute/attribute_populator.cpp8
-rw-r--r--searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp105
-rw-r--r--searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h21
-rw-r--r--searchcore/src/vespa/searchcore/proton/attribute/filter_attribute_manager.cpp2
-rw-r--r--searchcore/src/vespa/searchcore/proton/attribute/i_attribute_writer.h16
-rw-r--r--searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.cpp3
-rw-r--r--searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.h10
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.cpp27
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.h16
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.cpp51
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.h35
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.cpp52
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.h31
-rw-r--r--storage/src/tests/persistence/filestorage/filestormanagertest.cpp159
-rw-r--r--storage/src/vespa/storage/bucketdb/btree_lockable_map.h2
-rw-r--r--storage/src/vespa/storage/persistence/asynchandler.cpp13
-rw-r--r--storage/src/vespa/storage/persistence/asynchandler.h1
-rw-r--r--storage/src/vespa/storage/persistence/filestorage/filestorhandler.h29
-rw-r--r--storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.cpp37
-rw-r--r--storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.h4
-rw-r--r--storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp20
-rw-r--r--storage/src/vespa/storage/persistence/filestorage/filestormanager.h5
-rw-r--r--storage/src/vespa/storage/persistence/mergehandler.cpp2
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/DefaultZmsClient.java45
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/ZmsClient.java4
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/bindings/AssertionEntity.java52
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/bindings/PolicyEntity.java33
-rw-r--r--vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java17
90 files changed, 2448 insertions, 644 deletions
diff --git a/config-model/pom.xml b/config-model/pom.xml
index 95e79fd09fb..c0751431d03 100644
--- a/config-model/pom.xml
+++ b/config-model/pom.xml
@@ -46,6 +46,11 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>${protobuf.version}</version>
+ </dependency>
+ <dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<scope>provided</scope>
@@ -498,6 +503,10 @@
<updateReleaseInfo>true</updateReleaseInfo>
</configuration>
</plugin>
+ <plugin>
+ <groupId>com.github.os72</groupId>
+ <artifactId>protoc-jar-maven-plugin</artifactId>
+ </plugin>
</plugins>
</build>
</project>
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index 4011ce43841..b153ff62e7d 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition;
+import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
@@ -158,6 +159,12 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments())));
}
+ // A reference to an ONNX model?
+ Optional<TensorType> onnxFeatureType = onnxFeatureType(reference);
+ if (onnxFeatureType.isPresent()) {
+ return onnxFeatureType.get();
+ }
+
// A reference to a feature which returns a tensor?
Optional<TensorType> featureTensorType = tensorFeatureType(reference);
if (featureTensorType.isPresent()) {
@@ -210,6 +217,26 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return Optional.of(function);
}
+ private Optional<TensorType> onnxFeatureType(Reference reference) {
+ if ( ! reference.name().equals("onnxModel"))
+ return Optional.empty();
+
+ if ( ! featureTypes.containsKey(reference)) {
+ String configOrFileName = reference.arguments().expressions().get(0).toString();
+
+ // Look up standardized format as added in RankProfile
+ String modelConfigName = OnnxModelTransformer.getModelConfigName(reference);
+ String modelOutput = OnnxModelTransformer.getModelOutput(reference, null);
+
+ reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput);
+ if ( ! featureTypes.containsKey(reference)) {
+ throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'");
+ }
+ }
+
+ return Optional.of(featureTypes.get(reference));
+ }
+
/**
* There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet.
* This returns the type of those features if this is a reference to either of them, or empty otherwise.
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index c2fb2107604..5e8b8579ee6 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -2,14 +2,22 @@
package com.yahoo.searchdefinition;
import com.yahoo.config.FileReference;
+import com.yahoo.path.Path;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.AbstractService;
import com.yahoo.vespa.model.utils.FileSender;
+import onnx.Onnx;
-import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
-import java.util.List;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
/**
* A global ONNX model distributed using file distribution, similar to ranking constants.
@@ -21,16 +29,16 @@ public class OnnxModel {
public enum PathType {FILE, URI};
private final String name;
+ private PathType pathType = PathType.FILE;
private String path = null;
private String fileReference = "";
- private List<OnnxNameMapping> inputMap = new ArrayList<>();
- private List<OnnxNameMapping> outputMap = new ArrayList<>();
-
- public PathType getPathType() {
- return pathType;
- }
+ private String defaultOutput = null;
+ private Map<String, String> inputMap = new HashMap<>();
+ private Map<String, String> outputMap = new HashMap<>();
- private PathType pathType = PathType.FILE;
+ private Map<String, Onnx.TypeProto> inputTypes = new HashMap<>();
+ private Map<String, Onnx.TypeProto> outputTypes = new HashMap<>();
+ private Map<String, TensorType> vespaTypes = new HashMap<>();
public OnnxModel(String name) {
this.name = name;
@@ -49,21 +57,52 @@ public class OnnxModel {
}
public void setUri(String uri) {
- Objects.requireNonNull(uri, "uri cannot be null");
- this.path = uri;
- this.pathType = PathType.URI;
+ throw new IllegalArgumentException("URI for ONNX models are not currently supported");
+ }
+
+ public PathType getPathType() {
+ return pathType;
+ }
+
+ public void setDefaultOutput(String onnxName) {
+ Objects.requireNonNull(onnxName, "Name cannot be null");
+ this.defaultOutput = onnxName;
}
public void addInputNameMapping(String onnxName, String vespaName) {
+ addInputNameMapping(onnxName, vespaName, true);
+ }
+
+ public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
- this.inputMap.add(new OnnxNameMapping(onnxName, vespaName));
+ if (overwrite || ! inputMap.containsKey(onnxName)) {
+ inputMap.put(onnxName, vespaName);
+ }
}
public void addOutputNameMapping(String onnxName, String vespaName) {
+ addOutputNameMapping(onnxName, vespaName, true);
+ }
+
+ public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
- this.outputMap.add(new OnnxNameMapping(onnxName, vespaName));
+ if (overwrite || ! outputMap.containsKey(onnxName)) {
+ outputMap.put(onnxName, vespaName);
+ }
+ }
+
+ public void addInputType(String onnxName, Onnx.TypeProto type) {
+ Objects.requireNonNull(onnxName, "Onnx name cannot be null");
+ Objects.requireNonNull(type, "Tensor type cannot be null");
+ inputTypes.put(onnxName, type);
+ }
+
+ public void addOutputType(String onnxName, Onnx.TypeProto type) {
+ Objects.requireNonNull(onnxName, "Onnx name cannot be null");
+ Objects.requireNonNull(type, "Tensor type cannot be null");
+ outputTypes.put(onnxName, type);
}
/** Initiate sending of this constant to some services over file distribution */
@@ -76,11 +115,16 @@ public class OnnxModel {
public String getName() { return name; }
public String getFileName() { return path; }
+ public Path getFilePath() { return Path.fromString(path); }
public String getUri() { return path; }
public String getFileReference() { return fileReference; }
- public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); }
- public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); }
+ public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); }
+ public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); }
+
+ public String getDefaultOutput() {
+ return defaultOutput;
+ }
public void validate() {
if (path == null || path.isEmpty())
@@ -90,23 +134,151 @@ public class OnnxModel {
public String toString() {
StringBuilder b = new StringBuilder();
b.append("onnx-model '").append(name)
- .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path)
- .append("' with ref '").append(fileReference)
- .append("'");
+ .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path)
+ .append("' with ref '").append(fileReference)
+ .append("'");
return b.toString();
}
- public static class OnnxNameMapping {
- private String onnxName;
- private String vespaName;
+ /**
+ * Return the tensor type for an ONNX model for the given context.
+ * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output
+ * type depends on the input types for the given context (rank profile).
+ */
+ public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) {
+ Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName);
+ if (onnxOutputType == null) {
+ throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'");
+ }
+ if (allDimensionSizesAreKnown(onnxOutputType)) {
+ return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType));
+ }
+ return getTensorTypeWithUnknownDimensions(onnxOutputType, context);
+ }
+
+ private static boolean allDimensionSizesAreKnown(Onnx.TypeProto type) {
+ return type.getTensorType().getShape().getDimList().stream().noneMatch(d ->
+ (d.hasDimParam() && ! d.hasDimValue()) || d.getDimValue() == -1);
+ }
+
+ private TensorType getTensorTypeWithUnknownDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) {
+ long unboundSize = 0;
+ Map<String, Long> symbolicSizes = new HashMap<>();
+
+ for (String onnxInputName : inputTypes.keySet()) {
+ Onnx.TypeProto onnxType = inputTypes.get(onnxInputName);
+ if (allDimensionSizesAreKnown(onnxType)) {
+ continue;
+ }
+
+ Optional<TensorType> vespaType = resolveInputType(onnxInputName, context);
+ if (vespaType.isEmpty()) {
+ return TensorType.empty;
+ }
+
+ var onnxDimensions = onnxType.getTensorType().getShape().getDimList();
+ var vespaDimensions = vespaType.get().dimensions();
+ if (vespaDimensions.size() != onnxDimensions.size()) {
+ return TensorType.empty;
+ }
+
+ for (int i = 0; i < vespaDimensions.size(); ++i) {
+ if (vespaDimensions.get(i).size().isEmpty()) {
+ continue;
+ }
+ Long size = vespaDimensions.get(i).size().get();
+
+ // Handle dimensions with size -1 - typically batch dimensions
+ if (onnxDimensions.get(i).getDimValue() == -1) {
+ if (unboundSize != 0 && unboundSize != size) {
+ throw new IllegalArgumentException("Found conflicting sizes for unbound dimension " +
+ "for type '" + onnxOutputType + "' in ONNX model '" + name + "'");
+ }
+ unboundSize = size;
+
+ // Handle dimensions with symbolic names
+ } else if (onnxDimensions.get(i).hasDimParam()) {
+ String symbolicName = onnxDimensions.get(i).getDimParam();
+ if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) {
+ throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" +
+ symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'");
+ }
+ symbolicSizes.put(symbolicName, size);
+ }
+ }
+ }
+ return typeFrom(onnxOutputType, symbolicSizes, unboundSize);
+ }
+
+ private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) {
+ String source = inputMap.get(onnxInputName);
+ if (source != null) {
+ // Source is either a simple reference (query/attribute/constant)...
+ Optional<Reference> reference = Reference.simple(source);
+ if (reference.isPresent()) {
+ return Optional.of(context.getType(reference.get()));
+ }
+ // ... or a function
+ ExpressionFunction func = context.getFunction(source);
+ if (func != null) {
+ return Optional.of(func.getBody().type(context));
+ }
+ }
+ return Optional.empty(); // if this context does not contain this input
+ }
+
+ private static TensorType typeFrom(Onnx.TypeProto type) {
+ return typeFrom(type, null, 0);
+ }
+
+ private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes, long unboundSize) {
+ String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
+ Onnx.TensorShapeProto shape = type.getTensorType().getShape();
+ TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType()));
+ for (int i = 0; i < shape.getDimCount(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
+ long onnxDimensionSize = onnxDimension.getDimValue();
+ if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) {
+ onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam());
+ }
+ if (onnxDimensionSize == 0 && symbolicSizes != null) {
+ // This is for the case where all symbolic dimensions have
+ // different names, but can be resolved to a single dimension size.
+ Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values());
+ if (unknownSizes.size() == 1) {
+ onnxDimensionSize = unknownSizes.iterator().next();
+ }
+ }
+ if (onnxDimensionSize < 0) {
+ onnxDimensionSize = unboundSize;
+ }
+ if (onnxDimensionSize <= 0) {
+ throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " +
+ "ONNX type: " + type + " to Vespa tensor type.");
+ }
+ builder.indexed(dimensionName, onnxDimensionSize);
+ }
+ return builder.build();
+ }
- private OnnxNameMapping(String onnxName, String vespaName) {
- this.onnxName = onnxName;
- this.vespaName = vespaName;
+ private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return TensorType.Value.FLOAT;
+ case DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT8: return TensorType.Value.FLOAT;
+ case INT16: return TensorType.Value.FLOAT;
+ case INT32: return TensorType.Value.FLOAT;
+ case INT64: return TensorType.Value.FLOAT;
+ case UINT8: return TensorType.Value.FLOAT;
+ case UINT16: return TensorType.Value.FLOAT;
+ case UINT32: return TensorType.Value.FLOAT;
+ case UINT64: return TensorType.Value.FLOAT;
+ default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
}
- public String getOnnxName() { return onnxName; }
- public String getVespaName() { return vespaName; }
- public void setVespaName(String vespaName) { this.vespaName = vespaName; }
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index d309f48d6df..96c043bdb34 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -18,6 +18,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.VespaModel;
@@ -158,6 +159,10 @@ public class RankProfile implements Cloneable {
return search != null ? search.rankingConstants() : model.rankingConstants();
}
+ private Map<String, OnnxModel> onnxModels() {
+ return search != null ? search.onnxModels().asMap() : Collections.emptyMap();
+ }
+
private Stream<ImmutableSDField> allFields() {
if (search == null) return Stream.empty();
if (allFieldsList == null) {
@@ -821,6 +826,20 @@ public class RankProfile implements Cloneable {
}
}
+ // Add output types for ONNX models
+ for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) {
+ String modelName = entry.getKey();
+ OnnxModel model = entry.getValue();
+ Arguments args = new Arguments(new ReferenceNode(modelName));
+
+ TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), context);
+ context.setType(new Reference("onnxModel", args, null), defaultOutputType);
+
+ for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) {
+ TensorType type = model.getTensorType(mapping.getKey(), context);
+ context.setType(new Reference("onnxModel", args, mapping.getValue()), type);
+ }
+ }
return context;
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index 84442fedc48..22a32c8fd65 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -126,8 +126,8 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
modelBuilder.name(model.getName());
modelBuilder.fileref(model.getFileReference());
- model.getInputMap().forEach(mapper -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(mapper.getOnnxName()).source(mapper.getVespaName())));
- model.getOutputMap().forEach(mapper -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(mapper.getOnnxName()).as(mapper.getVespaName())));
+ model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
+ model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
builder.model(modelBuilder);
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 87eaaf0387a..56a5d539906 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -448,10 +448,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
Set<String> functionNames = rankProfile.getFunctions().keySet();
if (functionNames.isEmpty()) return;
for (OnnxModel onnxModel: rankProfile.getSearch().onnxModels().asMap().values()) {
- for (OnnxModel.OnnxNameMapping mapping : onnxModel.getInputMap()) {
- String source = mapping.getVespaName();
+ for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) {
+ String source = mapping.getValue();
if (functionNames.contains(source)) {
- mapping.setVespaName("rankingExpression(" + source + ")");
+ onnxModel.addInputNameMapping(mapping.getKey(), "rankingExpression(" + source + ")");
}
}
}
@@ -462,7 +462,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
Set<ReferenceNode> replacedSummaryFeatures = new HashSet<>();
for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) {
ReferenceNode referenceNode = i.next();
- ReferenceNode replacedNode = OnnxModelTransformer.transformFeature(referenceNode, rankProfile.getSearch());
+ ReferenceNode replacedNode = (ReferenceNode) OnnxModelTransformer.transformFeature(referenceNode, rankProfile);
if (referenceNode != replacedNode) {
replacedSummaryFeatures.add(replacedNode);
i.remove();
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index ec517768ea9..d23a8376e7a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -38,7 +38,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
}
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
- if ( ! feature.getName().equals("onnx")) return feature;
+ if ( ! feature.getName().equals("onnx") && ! feature.getName().equals("onnx_vespa")) return feature;
try {
FeatureArguments arguments = asFeatureArguments(feature.getArguments());
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
index e1ad003e5bd..69cdae10e47 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
@@ -1,20 +1,36 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
+import com.yahoo.path.Path;
import com.yahoo.searchdefinition.ImmutableSearch;
import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.vespa.model.ml.ConvertedModel;
+import com.yahoo.vespa.model.ml.FeatureArguments;
+import com.yahoo.vespa.model.ml.ModelName;
import java.util.List;
/**
- * Transforms instances of the onnxModel ranking feature and generates
- * ONNX configuration if necessary.
+ * Transforms ONNX model features of the forms:
+ *
+ * onnxModel(config_name)
+ * onnxModel(config_name).output
+ * onnxModel("path/to/model")
+ * onnxModel("path/to/model").output
+ * onnxModel("path/to/model", "path/to/output")
+ * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused
+ *
+ * To the format expected by the backend:
+ *
+ * onnxModel(config_name).output
*
* @author lesters
*/
@@ -33,85 +49,92 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
if (context.rankProfile() == null) return feature;
if (context.rankProfile().getSearch() == null) return feature;
- return transformFeature(feature, context.rankProfile().getSearch());
+ return transformFeature(feature, context.rankProfile());
}
- public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) {
- if (!feature.getName().equals("onnxModel")) return feature;
+ public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) {
+ ImmutableSearch search = rankProfile.getSearch();
+ final String featureName = feature.getName();
+ if ( ! featureName.equals("onnxModel")) return feature;
Arguments arguments = feature.getArguments();
if (arguments.isEmpty())
- throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " +
- "onnx-model config or a ONNX file.");
- if (arguments.expressions().size() > 2)
- throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments.");
-
- // Validation that the file actually exists is handled when the file is added to file distribution.
- // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator.
-
- String modelConfigName;
- OnnxModel onnxModel;
- if (arguments.expressions().get(0) instanceof ReferenceNode) {
- modelConfigName = arguments.expressions().get(0).toString();
- onnxModel = search.onnxModels().get(modelConfigName);
- if (onnxModel == null) {
- throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found");
- }
- } else if (arguments.expressions().get(0) instanceof ConstantNode) {
+ throw new IllegalArgumentException("An " + featureName + " feature must take an argument referring to a " +
+ "onnx-model config or an ONNX file.");
+ if (arguments.expressions().size() > 3)
+ throw new IllegalArgumentException("An " + featureName + " feature can have at most 3 arguments.");
+
+ // Check that the model configuration "onnx-model" exists. If not defined, it should have been added
+ // by the "OnnxModelConfigGenerator" processor. If it still doesn't exist, it is because we can't find
+ // the actual ONNX file, which can happen if we are restarting or upgrading an application using an
+ // ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store.
+
+ String modelConfigName = getModelConfigName(feature.reference());
+ OnnxModel onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
String path = asString(arguments.expressions().get(0));
- modelConfigName = asValidIdentifier(path);
- onnxModel = search.onnxModels().get(modelConfigName);
- if (onnxModel == null) {
- onnxModel = new OnnxModel(modelConfigName, path);
- search.onnxModels().add(onnxModel);
- }
- } else {
- throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'");
+ ModelName modelName = new ModelName(null, Path.fromString(path), true);
+ ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, path, rankProfile);
+ FeatureArguments featureArguments = new FeatureArguments(arguments);
+ return convertedModel.expression(featureArguments, null);
}
- String output = null;
- if (feature.getOutput() != null) {
- output = feature.getOutput();
- if ( ! hasOutputMapping(onnxModel, output)) {
- onnxModel.addOutputNameMapping(output, output);
+ String defaultOutput = onnxModel.getOutputMap().get(onnxModel.getDefaultOutput());
+ String output = getModelOutput(feature.reference(), defaultOutput);
+ if (! onnxModel.getOutputMap().containsValue(output)) {
+ throw new IllegalArgumentException(featureName + " argument '" + output +
+ "' output not found in model '" + onnxModel.getFileName() + "'");
+ }
+ return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output);
+ }
+
+ public static String getModelConfigName(Reference reference) {
+ if (reference.arguments().size() > 0) {
+ ExpressionNode expr = reference.arguments().expressions().get(0);
+ if (expr instanceof ReferenceNode) { // refers to onnx-model config
+ return expr.toString();
}
- } else if (arguments.expressions().size() > 1) {
- String name = asString(arguments.expressions().get(1));
- output = asValidIdentifier(name);
- if ( ! hasOutputMapping(onnxModel, output)) {
- onnxModel.addOutputNameMapping(name, output);
+            if (expr instanceof ConstantNode) {  // refers to a file path
+ return asValidIdentifier(expr);
}
}
+ return null;
+ }
- // Replace feature with name of config
- ExpressionNode argument = new ReferenceNode(modelConfigName);
- return new ReferenceNode("onnxModel", List.of(argument), output);
-
+ public static String getModelOutput(Reference reference, String defaultOutput) {
+ if (reference.output() != null) {
+ return reference.output();
+ } else if (reference.arguments().expressions().size() == 2) {
+ return asValidIdentifier(reference.arguments().expressions().get(1));
+ } else if (reference.arguments().expressions().size() > 2) {
+ return asValidIdentifier(reference.arguments().expressions().get(2));
+ }
+ return defaultOutput;
}
- private static boolean hasOutputMapping(OnnxModel onnxModel, String as) {
- return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as));
+ public static String stripQuotes(String s) {
+ if (isNotQuoteSign(s.codePointAt(0))) return s;
+ if (isNotQuoteSign(s.codePointAt(s.length() - 1)))
+ throw new IllegalArgumentException("argument [" + s + "] is missing end quote");
+ return s.substring(1, s.length()-1);
}
- private static String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
+ public static String asValidIdentifier(String str) {
+ return str.replaceAll("[^\\w\\d\\$@_]", "_");
}
- private static String stripQuotes(String s) {
- if ( ! isQuoteSign(s.codePointAt(0))) return s;
- if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
- throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
- return s.substring(1, s.length()-1);
+ private static String asValidIdentifier(ExpressionNode node) {
+ return asValidIdentifier(asString(node));
}
- private static boolean isQuoteSign(int c) {
- return c == '\'' || c == '"';
+ private static boolean isNotQuoteSign(int c) {
+ return c != '\'' && c != '"';
}
- private static String asValidIdentifier(String str) {
- return str.replaceAll("[^\\w\\d\\$@_]", "_");
+ public static String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
+ return stripQuotes(((ConstantNode)node).sourceString());
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
new file mode 100644
index 00000000000..afba88c135d
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
@@ -0,0 +1,97 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.application.api.DeployLogger;
+import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.vespa.model.container.search.QueryProfiles;
+
+import java.util.Map;
+
+/**
+ * Processes ONNX ranking features of the form:
+ *
+ * onnx("files/model.onnx", "path/to/output:1")
+ *
+ * And generates an "onnx-model" configuration as if it was defined in the schema:
+ *
+ * onnx-model files_model_onnx {
+ * file: "files/model.onnx"
+ * }
+ *
+ * Inputs and outputs are resolved in OnnxModelTypeResolver, which must be
+ * processed after this.
+ *
+ * @author lesters
+ */
+public class OnnxModelConfigGenerator extends Processor {
+
+ public OnnxModelConfigGenerator(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) {
+ super(search, deployLogger, rankProfileRegistry, queryProfiles);
+ }
+
+ @Override
+ public void process(boolean validate, boolean documentsOnly) {
+ if (documentsOnly) return;
+ for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) {
+ if (profile.getFirstPhaseRanking() != null) {
+ process(profile.getFirstPhaseRanking().getRoot());
+ }
+ if (profile.getSecondPhaseRanking() != null) {
+ process(profile.getSecondPhaseRanking().getRoot());
+ }
+ for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) {
+ process(function.getValue().function().getBody().getRoot());
+ }
+ for (ReferenceNode feature : profile.getSummaryFeatures()) {
+ process(feature);
+ }
+ }
+ }
+
+ private void process(ExpressionNode node) {
+ if (node instanceof ReferenceNode) {
+ process((ReferenceNode)node);
+ } else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode) node).children()) {
+ process(child);
+ }
+ }
+ }
+
+ private void process(ReferenceNode feature) {
+ if (feature.getName().equals("onnxModel") || feature.getName().equals("onnx")) {
+ if (feature.getArguments().size() > 0) {
+ if (feature.getArguments().expressions().get(0) instanceof ConstantNode) {
+ ConstantNode node = (ConstantNode) feature.getArguments().expressions().get(0);
+ String path = OnnxModelTransformer.stripQuotes(node.sourceString());
+ String modelConfigName = OnnxModelTransformer.asValidIdentifier(path);
+
+ // Only add the configuration if the model can actually be found.
+ if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) {
+ path = ApplicationPackage.MODELS_DIR.append(path).toString();
+ if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) {
+ return;
+ }
+ }
+
+ OnnxModel onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
+ onnxModel = new OnnxModel(modelConfigName, path);
+ search.onnxModels().add(onnxModel);
+ }
+ }
+ }
+ }
+ }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
new file mode 100644
index 00000000000..bead2e7e7c9
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
@@ -0,0 +1,154 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.cloud.config.ConfigserverConfig;
+import com.yahoo.component.Version;
+import com.yahoo.config.FileReference;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.application.api.DeployLogger;
+import com.yahoo.config.application.api.FileRegistry;
+import com.yahoo.path.Path;
+import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
+import com.yahoo.vespa.defaults.Defaults;
+import com.yahoo.vespa.model.container.search.QueryProfiles;
+import onnx.Onnx;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Paths;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Processes every "onnx-model" element in the schema. Parses the model file,
+ * adds missing input and output mappings (assigning default names), and
+ * adds tensor types to all model inputs and outputs.
+ *
+ * Must be processed before RankingExpressionTypeResolver.
+ *
+ * @author lesters
+ */
+public class OnnxModelTypeResolver extends Processor {
+
+ public OnnxModelTypeResolver(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) {
+ super(search, deployLogger, rankProfileRegistry, queryProfiles);
+ }
+
+ @Override
+ public void process(boolean validate, boolean documentsOnly) {
+ if (documentsOnly) return;
+
+ for (Map.Entry<String, OnnxModel> entry : search.onnxModels().asMap().entrySet()) {
+ OnnxModel modelConfig = entry.getValue();
+ try (InputStream inputStream = openModelFile(modelConfig.getFilePath())) {
+ Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
+
+ // Model inputs - if not defined, assumes a function is provided with a valid name
+ for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) {
+ String onnxInputName = valueInfo.getName();
+ String vespaInputName = OnnxModelTransformer.asValidIdentifier(onnxInputName);
+ modelConfig.addInputNameMapping(onnxInputName, vespaInputName, false);
+ modelConfig.addInputType(onnxInputName, valueInfo.getType());
+ }
+
+ // Model outputs
+ for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) {
+ String onnxOutputName = valueInfo.getName();
+ String vespaOutputName = OnnxModelTransformer.asValidIdentifier(onnxOutputName);
+ modelConfig.addOutputNameMapping(onnxOutputName, vespaOutputName, false);
+ modelConfig.addOutputType(onnxOutputName, valueInfo.getType());
+ }
+
+ // Set the first output as default
+ if ( ! model.getGraph().getOutputList().isEmpty()) {
+ modelConfig.setDefaultOutput(model.getGraph().getOutput(0).getName());
+ }
+
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Unable to parse ONNX model", e);
+ }
+ }
+ }
+
+ static boolean modelFileExists(String path, ApplicationPackage app) {
+ Path pathInApplicationPackage = Path.fromString(path);
+ if (getFile(pathInApplicationPackage, app).exists()) {
+ return true;
+ }
+ if (getFileReference(pathInApplicationPackage, app).isPresent()) {
+ return true;
+ }
+ return false;
+ }
+
+ private InputStream openModelFile(Path path) throws FileNotFoundException {
+ ApplicationFile file;
+ Optional<FileReference> reference;
+ Path modelsPath = ApplicationPackage.MODELS_DIR.append(path);
+
+ if ((file = getFile(path)).exists()) {
+ return file.createInputStream();
+ }
+ if ((file = getFile(modelsPath)).exists()) {
+ return file.createInputStream();
+ }
+ if ((reference = getFileReference(path)).isPresent()) {
+ return openFromFileRepository(path, reference.get());
+ }
+ if ((reference = getFileReference(modelsPath)).isPresent()) {
+ return openFromFileRepository(modelsPath, reference.get());
+ }
+
+ throw new IllegalArgumentException("Unable to find ONNX model file \"" + path + "\" " +
+ "in application package or file repository.");
+ }
+
+ private ApplicationFile getFile(Path path) {
+ return getFile(path, search.applicationPackage());
+ }
+
+ private static ApplicationFile getFile(Path path, ApplicationPackage app) {
+ return app.getFile(path);
+ }
+
+ private static InputStream openFromFileRepository(Path path, FileReference reference) throws FileNotFoundException {
+ return new FileInputStream(new File(getFileRepositoryPath(path, reference.value())));
+ }
+
+ public static String getFileRepositoryPath(Path path, String fileReference) {
+ ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults
+ String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir());
+ return Paths.get(fileRefDir, fileReference, path.getName()).toString();
+ }
+
+ private Optional<FileReference> getFileReference(Path path) {
+ return getFileReference(path, search.applicationPackage());
+ }
+
+ private static Optional<FileReference> getFileReference(Path path, ApplicationPackage app) {
+ Optional<FileRegistry> fileRegistry = getLatestFileRegistry(app);
+ if (fileRegistry.isPresent()) {
+ for (FileRegistry.Entry file : fileRegistry.get().export()) {
+ if (file.relativePath.equals(path.toString())) {
+ return Optional.of(file.reference);
+ }
+ }
+ }
+ return Optional.empty();
+ }
+
+ private static Optional<FileRegistry> getLatestFileRegistry(ApplicationPackage app) {
+ if (app == null) return Optional.empty();
+ Optional<Version> latest = app.getFileRegistries().keySet().stream().max(Version::compareTo);
+ return latest.isEmpty() ? Optional.empty() : Optional.of(app.getFileRegistries().get(latest.get()));
+ }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
index e8594c2a87f..1a3ef9e54b4 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
@@ -74,6 +74,8 @@ public class Processing {
ReferenceFieldsProcessor::new,
FastAccessValidator::new,
ReservedFunctionNames::new,
+ OnnxModelConfigGenerator::new,
+ OnnxModelTypeResolver::new,
RankingExpressionTypeResolver::new,
// These should be last:
IndexingValidation::new,
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
index c6c7969e466..d5c5183b01f 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
@@ -1,20 +1,19 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.application.validation;
-import com.yahoo.cloud.config.ConfigserverConfig;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.IOUtils;
import com.yahoo.log.InvalidLogFormatException;
import java.util.logging.Level;
import com.yahoo.log.LogMessage;
import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.processing.OnnxModelTypeResolver;
import com.yahoo.yolean.Exceptions;
import com.yahoo.system.ProcessExecuter;
import com.yahoo.text.StringUtilities;
import com.yahoo.vespa.config.search.AttributesConfig;
import com.yahoo.collections.Pair;
import com.yahoo.config.ConfigInstance;
-import com.yahoo.vespa.defaults.Defaults;
import com.yahoo.vespa.config.search.ImportedFieldsConfig;
import com.yahoo.vespa.config.search.IndexschemaConfig;
import com.yahoo.vespa.config.search.RankProfilesConfig;
@@ -31,7 +30,6 @@ import com.yahoo.vespa.model.search.SearchCluster;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
-import java.nio.file.Paths;
import java.time.Duration;
import java.time.Instant;
import java.util.logging.Logger;
@@ -152,12 +150,9 @@ public class RankSetupValidator extends Validator {
// Assist verify-ranksetup in finding the actual ONNX model files
Map<String, OnnxModel> models = db.getDerivedConfiguration().getSearch().onnxModels().asMap();
if (models.values().size() > 0) {
- ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults
- String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir());
List<String> config = new ArrayList<>(models.values().size() * 2);
for (OnnxModel model : models.values()) {
- String modelFilename = Paths.get(model.getFileName()).getFileName().toString();
- String modelPath = Paths.get(fileRefDir, model.getFileReference(), modelFilename).toString();
+ String modelPath = OnnxModelTypeResolver.getFileRepositoryPath(model.getFilePath(), model.getFileReference());
config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference()));
config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath));
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 943fcbf6c1d..5ee6ed02e61 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -150,7 +150,7 @@ public class ConvertedModel {
*/
public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) {
ExpressionFunction expression = selectExpression(arguments);
- if (sourceModel.isPresent()) // we should verify
+ if (sourceModel.isPresent() && context != null) // we should verify
verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles());
return expression.getBody().getRoot();
}
diff --git a/config-model/src/main/protobuf/onnx.proto b/config-model/src/main/protobuf/onnx.proto
new file mode 100644
index 00000000000..dc6542867e0
--- /dev/null
+++ b/config-model/src/main/protobuf/onnx.proto
@@ -0,0 +1,464 @@
+//
+// WARNING: This file is automatically generated! Please edit onnx.in.proto.
+//
+
+
+// Copyright (c) Facebook Inc. and Microsoft Corporation.
+// Licensed under the MIT license.
+
+syntax = "proto2";
+
+package onnx;
+
+// Overview
+//
+// ONNX is an open specification that is comprised of the following components:
+//
+// 1) A definition of an extensible computation graph model.
+// 2) Definitions of standard data types.
+// 3) Definitions of built-in operators.
+//
+// This document describes the syntax of models and their computation graphs,
+// as well as the standard data types. Together, they are referred to as the ONNX
+// Intermediate Representation, or 'IR' for short.
+//
+// The normative semantic specification of the ONNX IR is found in docs/IR.md.
+// Definitions of the built-in neural network operators may be found in docs/Operators.md.
+
+// Notes
+//
+// Release
+//
+// We are still in the very early stage of defining ONNX. The current
+// version of ONNX is a starting point. While we are actively working
+// towards a complete spec, we would like to get the community involved
+// by sharing our working version of ONNX.
+//
+// Protobuf compatibility
+//
+// To simplify framework compatibility, ONNX is defined using the subset of protobuf
+// that is compatible with both protobuf v2 and v3. This means that we do not use any
+// protobuf features that are only available in one of the two versions.
+//
+// Here are the most notable contortions we have to carry out to work around
+// these limitations:
+//
+// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
+// of key-value pairs, where order does not matter and duplicates
+// are not allowed.
+
+
+// Versioning
+//
+// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
+//
+// To be compatible with both proto2 and proto3, we will use a version number
+// that is not defined by the default value but an explicit enum number.
+enum Version {
+ // proto3 requires the first enum value to be zero.
+ // We add this just to appease the compiler.
+ _START_VERSION = 0;
+ // The version field is always serialized and we will use it to store the
+ // version that the graph is generated from. This helps us set up version
+ // control. We should use version as
+ // xx(major) - xx(minor) - xxxx(bugfix)
+ // and we are starting with 0x00000001 (0.0.1), which was the
+ // version we published on Oct 10, 2017.
+ IR_VERSION_2017_10_10 = 0x00000001;
+
+ // IR_VERSION 0.0.2 published on Oct 30, 2017
+ // - Added type discriminator to AttributeProto to support proto3 users
+ IR_VERSION_2017_10_30 = 0x00000002;
+
+ // IR VERSION 0.0.3 published on Nov 3, 2017
+ // - For operator versioning:
+ // - Added new message OperatorSetIdProto
+ // - Added opset_import in ModelProto
+ // - For vendor extensions, added domain in NodeProto
+ IR_VERSION = 0x00000003;
+}
+
+// Attributes
+//
+// A named attribute containing either singular float, integer, string, graph,
+// and tensor values, or repeated float, integer, string, graph, and tensor values.
+// An AttributeProto MUST contain the name field, and *only one* of the
+// following content fields, effectively enforcing a C/C++ union equivalent.
+message AttributeProto {
+
+ // Note: this enum is structurally identical to the OpSchema::AttrType
+ // enum defined in schema.h. If you rev one, you likely need to rev the other.
+ enum AttributeType {
+ UNDEFINED = 0;
+ FLOAT = 1;
+ INT = 2;
+ STRING = 3;
+ TENSOR = 4;
+ GRAPH = 5;
+
+ FLOATS = 6;
+ INTS = 7;
+ STRINGS = 8;
+ TENSORS = 9;
+ GRAPHS = 10;
+ }
+
+ // The name field MUST be present for this version of the IR.
+ optional string name = 1; // namespace Attribute
+
+ // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
+ // In this case, this AttributeProto does not contain data, and it's a reference of attribute
+ // in parent scope.
+ // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
+ optional string ref_attr_name = 21;
+
+ // A human-readable documentation for this attribute. Markdown is allowed.
+ optional string doc_string = 13;
+
+ // The type field MUST be present for this version of the IR.
+ // For 0.0.1 versions of the IR, this field was not defined, and
+ // implementations needed to use has_field hueristics to determine
+ // which value field was in use. For IR_VERSION 0.0.2 or later, this
+ // field MUST be set and match the f|i|s|t|... field in use. This
+ // change was made to accomodate proto3 implementations.
+ optional AttributeType type = 20; // discriminator that indicates which field below is in use
+
+ // Exactly ONE of the following fields must be present for this version of the IR
+ optional float f = 2; // float
+ optional int64 i = 3; // int
+ optional bytes s = 4; // UTF-8 string
+ optional TensorProto t = 5; // tensor value
+ optional GraphProto g = 6; // graph
+ // Do not use field below, it's deprecated.
+ // optional ValueProto v = 12; // value - subsumes everything but graph
+
+ repeated float floats = 7; // list of floats
+ repeated int64 ints = 8; // list of ints
+ repeated bytes strings = 9; // list of UTF-8 strings
+ repeated TensorProto tensors = 10; // list of tensors
+ repeated GraphProto graphs = 11; // list of graph
+}
+
+// Defines information on value, including the name, the type, and
+// the shape of the value.
+message ValueInfoProto {
+ // This field MUST be present in this version of the IR.
+ optional string name = 1; // namespace Value
+ // This field MUST be present in this version of the IR.
+ optional TypeProto type = 2;
+ // A human-readable documentation for this value. Markdown is allowed.
+ optional string doc_string = 3;
+}
+
+// Nodes
+//
+// Computation graphs are made up of a DAG of nodes, which represent what is
+// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
+//
+// For example, it can be a node of type "Conv" that takes in an image, a filter
+// tensor and a bias tensor, and produces the convolved output.
+message NodeProto {
+ repeated string input = 1; // namespace Value
+ repeated string output = 2; // namespace Value
+
+ // An optional identifier for this node in a graph.
+ // This field MAY be absent in ths version of the IR.
+ optional string name = 3; // namespace Node
+
+ // The symbolic identifier of the Operator to execute.
+ optional string op_type = 4; // namespace Operator
+ // The domain of the OperatorSet that specifies the operator named by op_type.
+ optional string domain = 7; // namespace Domain
+
+ // Additional named attributes.
+ repeated AttributeProto attribute = 5;
+
+ // A human-readable documentation for this node. Markdown is allowed.
+ optional string doc_string = 6;
+}
+
+// Models
+//
+// ModelProto is a top-level file/container format for bundling a ML model and
+// associating its computation graph with metadata.
+//
+// The semantics of the model are described by the associated GraphProto.
+message ModelProto {
+ // The version of the IR this model targets. See Version enum above.
+ // This field MUST be present.
+ optional int64 ir_version = 1;
+
+ // The OperatorSets this model relies on.
+ // All ModelProtos MUST have at least one entry that
+ // specifies which version of the ONNX OperatorSet is
+ // being imported.
+ //
+ // All nodes in the ModelProto's graph will bind against the operator
+ // with the same-domain/same-op_type operator with the HIGHEST version
+ // in the referenced operator sets.
+ repeated OperatorSetIdProto opset_import = 8;
+
+ // The name of the framework or tool used to generate this model.
+ // This field SHOULD be present to indicate which implementation/tool/framework
+ // emitted the model.
+ optional string producer_name = 2;
+
+ // The version of the framework or tool used to generate this model.
+ // This field SHOULD be present to indicate which implementation/tool/framework
+ // emitted the model.
+ optional string producer_version = 3;
+
+ // Domain name of the model.
+ // We use reverse domain names as name space indicators. For example:
+ // `com.facebook.fair` or `com.microsoft.cognitiveservices`
+ //
+ // Together with `model_version` and GraphProto.name, this forms the unique identity of
+ // the graph.
+ optional string domain = 4;
+
+ // The version of the graph encoded. See Version enum below.
+ optional int64 model_version = 5;
+
+ // A human-readable documentation for this model. Markdown is allowed.
+ optional string doc_string = 6;
+
+ // The parameterized graph that is evaluated to execute the model.
+ optional GraphProto graph = 7;
+
+ // Named metadata values; keys should be distinct.
+ repeated StringStringEntryProto metadata_props = 14;
+};
+
+// StringStringEntryProto follows the pattern for cross-proto-version maps.
+// See https://developers.google.com/protocol-buffers/docs/proto3#maps
+message StringStringEntryProto {
+ optional string key = 1;
+ optional string value= 2;
+};
+
+// Graphs
+//
+// A graph defines the computational logic of a model and is comprised of a parameterized
+// list of nodes that form a directed acyclic graph based on their inputs and outputs.
+// This is the equivalent of the "network" or "graph" in many deep learning
+// frameworks.
+message GraphProto {
+ // The nodes in the graph, sorted topologically.
+ repeated NodeProto node = 1;
+
+ // The name of the graph.
+ optional string name = 2; // namespace Graph
+
+ // A list of named tensor values, used to specify constant inputs of the graph.
+ // Each TensorProto entry must have a distinct name (within the list) that
+ // also appears in the input list.
+ repeated TensorProto initializer = 5;
+
+ // A human-readable documentation for this graph. Markdown is allowed.
+ optional string doc_string = 10;
+
+ // The inputs and outputs of the graph.
+ repeated ValueInfoProto input = 11;
+ repeated ValueInfoProto output = 12;
+
+ // Information for the values in the graph. The ValueInfoProto.name's
+ // must be distinct. It is optional for a value to appear in value_info list.
+ repeated ValueInfoProto value_info = 13;
+
+ // DO NOT USE the following fields, they were deprecated from earlier versions.
+ // repeated string input = 3;
+ // repeated string output = 4;
+ // optional int64 ir_version = 6;
+ // optional int64 producer_version = 7;
+ // optional string producer_tag = 8;
+ // optional string domain = 9;
+}
+
+// Tensors
+//
+// A serialized tensor value.
+message TensorProto {
+ enum DataType {
+ UNDEFINED = 0;
+ // Basic types.
+ FLOAT = 1; // float
+ UINT8 = 2; // uint8_t
+ INT8 = 3; // int8_t
+ UINT16 = 4; // uint16_t
+ INT16 = 5; // int16_t
+ INT32 = 6; // int32_t
+ INT64 = 7; // int64_t
+ STRING = 8; // string
+ BOOL = 9; // bool
+
+ // Advanced types
+ FLOAT16 = 10;
+ DOUBLE = 11;
+ UINT32 = 12;
+ UINT64 = 13;
+ COMPLEX64 = 14; // complex with float32 real and imaginary components
+ COMPLEX128 = 15; // complex with float64 real and imaginary components
+ // Future extensions go here.
+ }
+
+ // The shape of the tensor.
+ repeated int64 dims = 1;
+
+ // The data type of the tensor.
+ optional DataType data_type = 2;
+
+ // For very large tensors, we may want to store them in chunks, in which
+ // case the following fields will specify the segment that is stored in
+ // the current TensorProto.
+ message Segment {
+ optional int64 begin = 1;
+ optional int64 end = 2;
+ }
+ optional Segment segment = 3;
+
+ // Tensor content must be organized in row-major order.
+ //
+ // Depending on the data_type field, exactly one of the fields below with
+ // name ending in _data is used to store the elements of the tensor.
+
+ // For float and complex64 values
+ // Complex64 tensors are encoded as a single array of floats,
+ // with the real components appearing in odd numbered positions,
+  // and the corresponding imaginary component appearing in the
+  // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
+  // is encoded as [1.0, 2.0, 3.0, 4.0]).
+ // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
+ repeated float float_data = 4 [packed = true];
+
+ // For int32, uint8, int8, uint16, int16, bool, and float16 values
+ // float16 values must be bit-wise converted to an uint16_t prior
+ // to writing to the buffer.
+ // When this field is present, the data_type field MUST be
+  // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16
+ repeated int32 int32_data = 5 [packed = true];
+
+ // For strings.
+ // Each element of string_data is a UTF-8 encoded Unicode
+ // string. No trailing null, no leading BOM. The protobuf "string"
+ // scalar type is not used to match ML community conventions.
+ // When this field is present, the data_type field MUST be STRING
+ repeated bytes string_data = 6;
+
+ // For int64.
+ // When this field is present, the data_type field MUST be INT64
+ repeated int64 int64_data = 7 [packed = true];
+
+ // Optionally, a name for the tensor.
+ optional string name = 8; // namespace Value
+
+ // A human-readable documentation for this tensor. Markdown is allowed.
+ optional string doc_string = 12;
+
+ // Serializations can either use one of the fields above, or use this
+ // raw bytes field. The only exception is the string case, where one is
+ // required to store the content in the repeated bytes string_data field.
+ //
+ // When this raw_data field is used to store tensor value, elements MUST
+ // be stored in as fixed-width, little-endian order.
+ // Floating-point data types MUST be stored in IEEE 754 format.
+ // Complex64 elements must be written as two consecutive FLOAT values, real component first.
+ // Complex128 elements must be written as two consecutive DOUBLE values, real component first.
+ // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
+ //
+ // Note: the advantage of specific field rather than the raw_data field is
+ // that in some cases (e.g. int data), protobuf does a better packing via
+ // variable length storage, and may lead to smaller binary footprint.
+ // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
+ optional bytes raw_data = 9;
+
+ // For double
+  // Complex128 tensors are encoded as a single array of doubles,
+  // with the real components appearing in odd numbered positions,
+  // and the corresponding imaginary component appearing in the
+  // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
+  // is encoded as [1.0, 2.0, 3.0, 4.0]).
+ // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
+ repeated double double_data = 10 [packed = true];
+
+ // For uint64 and uint32 values
+ // When this field is present, the data_type field MUST be
+ // UINT32 or UINT64
+ repeated uint64 uint64_data = 11 [packed = true];
+}
+
+// Defines a tensor shape. A dimension can be either an integer value
+// or a symbolic variable. A symbolic variable represents an unknown
+// dimension.
+message TensorShapeProto {
+ message Dimension {
+ oneof value {
+ int64 dim_value = 1;
+ string dim_param = 2; // namespace Shape
+ };
+ // Standard denotation can optionally be used to denote tensor
+ // dimensions with standard semantic descriptions to ensure
+ // that operations are applied to the correct axis of a tensor.
+ optional string denotation = 3;
+ };
+ repeated Dimension dim = 1;
+}
+
+// A set of pre-defined constants to be used as values for
+// the standard denotation field in TensorShapeProto.Dimension
+// for semantic description of the tensor dimension.
+message DenotationConstProto {
+ // Describe a batch number dimension.
+ optional string DATA_BATCH = 1 [default = "DATA_BATCH"];
+ // Describe a channel dimension.
+ optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"];
+ // Describe a time dimension.
+ optional string DATA_TIME = 3 [default = "DATA_TIME"];
+ // Describe a feature dimension. This is typically a feature
+ // dimension in RNN and/or spatial dimension in CNN.
+ optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"];
+ // Describe a filter in-channel dimension. This is the dimension
+ // that is identical (in size) to the channel dimension of the input
+ // image feature maps.
+ optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"];
+ // Describe a filter out channel dimension. This is the dimension
+  // that is identical (in size) to the channel dimension of the output
+ // image feature maps.
+ optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"];
+ // Describe a filter spatial dimension.
+ optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"];
+}
+
+// Types
+//
+// The standard ONNX data types.
+message TypeProto {
+
+ message Tensor {
+ // This field MUST NOT have the value of UNDEFINED
+ // This field MUST be present for this version of the IR.
+ optional TensorProto.DataType elem_type = 1;
+ optional TensorShapeProto shape = 2;
+ }
+
+
+ oneof value {
+ // The type of a tensor.
+ Tensor tensor_type = 1;
+
+ }
+}
+
+// Operator Sets
+//
+// OperatorSets are uniquely identified by a (domain, opset_version) pair.
+message OperatorSetIdProto {
+ // The domain of the operator set being identified.
+ // The empty string ("") or absence of this field implies the operator
+ // set that is defined as part of the ONNX specification.
+ // This field MUST be present in this version of the IR when referring to any other operator set.
+ optional string domain = 1;
+
+ // The version of the operator set being identified.
+ // This field MUST be present in this version of the IR.
+ optional int64 version = 2;
+} \ No newline at end of file
diff --git a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd
index e9575af6010..cc73f2daff5 100644
--- a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd
+++ b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd
@@ -18,7 +18,7 @@ search test {
}
function mnist_softmax_onnx() {
- expression: onnx("mnist_softmax")
+ expression: onnx_vespa("mnist_softmax")
}
function my_xgboost() {
diff --git a/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py b/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py
new file mode 100755
index 00000000000..55df3a557e9
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py
@@ -0,0 +1,12 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "sequence"])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "sequence"])
+
+nodes = [helper.make_node('Identity', ['input'], ['output'])]
+graph_def = helper.make_graph( nodes, 'simple_scoring', [INPUT], [OUTPUT])
+model_def = helper.make_model(graph_def, producer_name='create_dynamic_model.py')
+onnx.save(model_def, 'dynamic_model.onnx')
diff --git a/config-model/src/test/integration/onnx-model/files/create_model.py b/config-model/src/test/integration/onnx-model/files/create_model.py
new file mode 100755
index 00000000000..10ff92c2eda
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/files/create_model.py
@@ -0,0 +1,37 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('first_input', TensorProto.FLOAT, [2])
+INPUT_2 = helper.make_tensor_value_info('second/input:0', TensorProto.FLOAT, [2])
+INPUT_3 = helper.make_tensor_value_info('third_input', TensorProto.FLOAT, [2])
+OUTPUT_1 = helper.make_tensor_value_info('path/to/output:0', TensorProto.FLOAT, [2])
+OUTPUT_2 = helper.make_tensor_value_info('path/to/output:1', TensorProto.FLOAT, [2])
+OUTPUT_3 = helper.make_tensor_value_info('path/to/output:2', TensorProto.FLOAT, [2])
+
+nodes = [
+ helper.make_node(
+ 'Add',
+ ['first_input', 'second/input:0'],
+ ['path/to/output:0'],
+ ),
+ helper.make_node(
+ 'Add',
+ ['third_input', 'second/input:0'],
+ ['path/to/output:1']
+ ),
+ helper.make_node(
+ 'Add',
+ ['path/to/output:0', 'path/to/output:1'],
+ ['path/to/output:2']
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'simple_scoring',
+ [INPUT_1, INPUT_2, INPUT_3],
+ [OUTPUT_1, OUTPUT_2, OUTPUT_3]
+)
+model_def = helper.make_model(graph_def, producer_name='create_model.py')
+onnx.save(model_def, 'model.onnx')
diff --git a/config-model/src/test/integration/onnx-model/files/create_unbound_model.py b/config-model/src/test/integration/onnx-model/files/create_unbound_model.py
new file mode 100755
index 00000000000..abf733ea43f
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/files/create_unbound_model.py
@@ -0,0 +1,12 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT = helper.make_tensor_value_info('input', TensorProto.FLOAT, [-1, 2])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [-1, 2])
+
+nodes = [helper.make_node('Identity', ['input'], ['output'])]
+graph_def = helper.make_graph( nodes, 'simple_scoring', [INPUT], [OUTPUT])
+model_def = helper.make_model(graph_def, producer_name='create_unbound_model.py')
+onnx.save(model_def, 'unbound_model.onnx')
diff --git a/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx b/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx
new file mode 100644
index 00000000000..6bbdad2d76e
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx
@@ -0,0 +1,13 @@
+create_dynamic_model.py:x
+
+inputoutput"Identitysimple_scoringZ$
+input
+
+batch
+
+sequenceb%
+output
+
+batch
+
+sequenceB \ No newline at end of file
diff --git a/config-model/src/test/integration/onnx-model/files/model.onnx b/config-model/src/test/integration/onnx-model/files/model.onnx
new file mode 100644
index 00000000000..f3898205c6a
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/files/model.onnx
@@ -0,0 +1,34 @@
+create_model.py:í
+4
+ first_input
+second/input:0path/to/output:0"Add
+4
+ third_input
+second/input:0path/to/output:1"Add
+;
+path/to/output:0
+path/to/output:1path/to/output:2"Addsimple_scoringZ
+ first_input
+
+
+Z
+second/input:0
+
+
+Z
+ third_input
+
+
+b
+path/to/output:0
+
+
+b
+path/to/output:1
+
+
+b
+path/to/output:2
+
+
+B \ No newline at end of file
diff --git a/config-model/src/test/integration/onnx-model/files/summary_model.onnx b/config-model/src/test/integration/onnx-model/files/summary_model.onnx
new file mode 100644
index 00000000000..f3898205c6a
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/files/summary_model.onnx
@@ -0,0 +1,34 @@
+create_model.py:í
+4
+ first_input
+second/input:0path/to/output:0"Add
+4
+ third_input
+second/input:0path/to/output:1"Add
+;
+path/to/output:0
+path/to/output:1path/to/output:2"Addsimple_scoringZ
+ first_input
+
+
+Z
+second/input:0
+
+
+Z
+ third_input
+
+
+b
+path/to/output:0
+
+
+b
+path/to/output:1
+
+
+b
+path/to/output:2
+
+
+B \ No newline at end of file
diff --git a/config-model/src/test/integration/onnx-model/files/unbound_model.onnx b/config-model/src/test/integration/onnx-model/files/unbound_model.onnx
new file mode 100644
index 00000000000..155b3125256
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/files/unbound_model.onnx
@@ -0,0 +1,11 @@
+create_unbound_model.py:p
+
+inputoutput"Identitysimple_scoringZ
+input
+
+ ÿÿÿÿÿÿÿÿÿ
+b!
+output
+
+ ÿÿÿÿÿÿÿÿÿ
+B \ No newline at end of file
diff --git a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd
index 0f0fa694e6f..a87222e77ee 100644
--- a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd
+++ b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd
@@ -14,7 +14,7 @@ search test {
}
onnx-model my_model {
- file: files/ranking_model.onnx
+ file: files/model.onnx
input first_input: attribute(document_field)
input "second/input:0": constant(my_constant)
input "third_input": my_function
@@ -22,19 +22,31 @@ search test {
}
onnx-model another_model {
- file: files/ranking_model.onnx
+ file: files/model.onnx
input first_input: attribute(document_field)
input "second/input:0": constant(my_constant)
input "third_input": another_function
output "path/to/output:2": out
}
+ onnx-model dynamic_model {
+ file: files/dynamic_model.onnx
+ input input: my_function
+ output output: my_output
+ }
+
+ onnx-model unbound_model {
+ file: files/unbound_model.onnx
+ input input: my_function
+ output output: my_output
+ }
+
rank-profile test_model_config {
function my_function() {
expression: tensor(d0[2])(1)
}
first-phase {
- expression: onnxModel(my_model).out
+ expression: onnxModel(my_model).out{d0:1}
}
}
@@ -49,7 +61,7 @@ search test {
expression: my_function()
}
first-phase {
- expression: onnxModel("files/ranking_model.onnx", "path/to/output:1")
+ expression: onnxModel("files/model.onnx", "path/to/output:1"){d0:1}
}
}
@@ -62,9 +74,39 @@ search test {
}
summary-features {
onnxModel(another_model).out
- onnxModel("files/ranking_model.onnx", "path/to/output:2")
+ onnxModel("files/summary_model.onnx", "path/to/output:2")
}
+ }
+ rank-profile test_dynamic_model {
+ function my_function() {
+ expression: tensor(d0[1],d1[2])(d1)
+ }
+ first-phase {
+ expression: onnxModel(dynamic_model){d0:0,d1:1}
+ }
}
+ rank-profile test_dynamic_model_2 {
+ function my_function_2() {
+ expression: tensor(d0[1],d1[3])(d1)
+ }
+ function my_function() {
+ expression: my_function_2()
+ }
+ first-phase {
+ expression: onnxModel(dynamic_model){d0:0,d1:2}
+ }
+ }
+
+ rank-profile test_unbound_model {
+ function my_function() {
+ expression: tensor(d0[1],d1[2])(d1)
+ }
+ first-phase {
+ expression: onnxModel(unbound_model){d0:0,d1:1}
+ }
+ }
+
+
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
index d9b0c70dfdd..4eb8681c374 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -25,43 +25,68 @@ public class RankingExpressionWithOnnxModelTestCase {
OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder();
((OnnxModelsConfig.Producer) db).getConfig(builder);
OnnxModelsConfig config = new OnnxModelsConfig(builder);
- assertEquals(3, config.model().size());
+ assertEquals(6, config.model().size());
- assertEquals("my_model", config.model(1).name());
+ assertEquals("my_model", config.model(0).name());
+ assertEquals(3, config.model(0).input().size());
+ assertEquals("second/input:0", config.model(0).input(0).name());
+ assertEquals("constant(my_constant)", config.model(0).input(0).source());
+ assertEquals("first_input", config.model(0).input(1).name());
+ assertEquals("attribute(document_field)", config.model(0).input(1).source());
+ assertEquals("third_input", config.model(0).input(2).name());
+ assertEquals("rankingExpression(my_function)", config.model(0).input(2).source());
+ assertEquals(3, config.model(0).output().size());
+ assertEquals("path/to/output:0", config.model(0).output(0).name());
+ assertEquals("out", config.model(0).output(0).as());
+ assertEquals("path/to/output:1", config.model(0).output(1).name());
+ assertEquals("path_to_output_1", config.model(0).output(1).as());
+ assertEquals("path/to/output:2", config.model(0).output(2).name());
+ assertEquals("path_to_output_2", config.model(0).output(2).as());
+
+ assertEquals("files_model_onnx", config.model(1).name());
assertEquals(3, config.model(1).input().size());
- assertEquals("first_input", config.model(1).input(0).name());
- assertEquals("attribute(document_field)", config.model(1).input(0).source());
- assertEquals("second/input:0", config.model(1).input(1).name());
- assertEquals("constant(my_constant)", config.model(1).input(1).source());
- assertEquals("third_input", config.model(1).input(2).name());
- assertEquals("rankingExpression(my_function)", config.model(1).input(2).source());
- assertEquals(1, config.model(1).output().size());
+ assertEquals(3, config.model(1).output().size());
assertEquals("path/to/output:0", config.model(1).output(0).name());
- assertEquals("out", config.model(1).output(0).as());
-
- assertEquals("files_ranking_model_onnx", config.model(0).name());
- assertEquals(0, config.model(0).input().size());
- assertEquals(2, config.model(0).output().size());
- assertEquals("path/to/output:1", config.model(0).output(0).name());
- assertEquals("path_to_output_1", config.model(0).output(0).as());
- assertEquals("path/to/output:2", config.model(0).output(1).name());
- assertEquals("path_to_output_2", config.model(0).output(1).as());
+ assertEquals("path_to_output_0", config.model(1).output(0).as());
+ assertEquals("path/to/output:1", config.model(1).output(1).name());
+ assertEquals("path_to_output_1", config.model(1).output(1).as());
+ assertEquals("path/to/output:2", config.model(1).output(2).name());
+ assertEquals("path_to_output_2", config.model(1).output(2).as());
+ assertEquals("files_model_onnx", config.model(1).name());
assertEquals("another_model", config.model(2).name());
assertEquals("third_input", config.model(2).input(2).name());
assertEquals("rankingExpression(another_function)", config.model(2).input(2).source());
+
+ assertEquals("files_summary_model_onnx", config.model(3).name());
+ assertEquals(3, config.model(3).input().size());
+ assertEquals(3, config.model(3).output().size());
+
+ assertEquals("dynamic_model", config.model(5).name());
+ assertEquals(1, config.model(5).input().size());
+ assertEquals(1, config.model(5).output().size());
+ assertEquals("rankingExpression(my_function)", config.model(5).input(0).source());
+
+ assertEquals("unbound_model", config.model(4).name());
+ assertEquals(1, config.model(4).input().size());
+ assertEquals(1, config.model(4).output().size());
+ assertEquals("rankingExpression(my_function)", config.model(4).input(0).source());
+
+
}
private void assertTransformedFeature(DocumentDatabase db) {
RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder();
((RankProfilesConfig.Producer) db).getConfig(builder);
RankProfilesConfig config = new RankProfilesConfig(builder);
- assertEquals(5, config.rankprofile().size());
+ assertEquals(8, config.rankprofile().size());
assertEquals("test_model_config", config.rankprofile(2).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name());
assertEquals("vespa.rank.firstphase", config.rankprofile(2).fef().property(2).name());
- assertEquals("onnxModel(my_model).out", config.rankprofile(2).fef().property(2).value());
+ assertEquals("rankingExpression(firstphase)", config.rankprofile(2).fef().property(2).value());
+ assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(2).fef().property(3).name());
+ assertEquals("onnxModel(my_model).out{d0:1}", config.rankprofile(2).fef().property(3).value());
assertEquals("test_generated_model_config", config.rankprofile(3).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(3).fef().property(0).name());
@@ -69,16 +94,34 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("rankingExpression(second_input).rankingScript", config.rankprofile(3).fef().property(4).name());
assertEquals("rankingExpression(third_input).rankingScript", config.rankprofile(3).fef().property(6).name());
assertEquals("vespa.rank.firstphase", config.rankprofile(3).fef().property(8).name());
- assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_1", config.rankprofile(3).fef().property(8).value());
+ assertEquals("rankingExpression(firstphase)", config.rankprofile(3).fef().property(8).value());
+ assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(3).fef().property(9).name());
+ assertEquals("onnxModel(files_model_onnx).path_to_output_1{d0:1}", config.rankprofile(3).fef().property(9).value());
assertEquals("test_summary_features", config.rankprofile(4).name());
assertEquals("rankingExpression(another_function).rankingScript", config.rankprofile(4).fef().property(0).name());
assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(4).fef().property(3).name());
assertEquals("1", config.rankprofile(4).fef().property(3).value());
assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(4).name());
- assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(4).value());
+ assertEquals("onnxModel(files_summary_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(4).value());
assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(5).name());
- assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(5).value());
+ assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(5).value());
+
+ assertEquals("test_dynamic_model", config.rankprofile(5).name());
+ assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(5).fef().property(0).name());
+ assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(5).fef().property(3).name());
+ assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:1}", config.rankprofile(5).fef().property(3).value());
+
+ assertEquals("test_dynamic_model_2", config.rankprofile(6).name());
+ assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(6).fef().property(5).name());
+ assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value());
+
+ assertEquals("test_unbound_model", config.rankprofile(7).name());
+ assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(0).name());
+ assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(7).fef().property(3).name());
+ assertEquals("onnxModel(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(7).fef().property(3).value());
+
+
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 6bf69907609..40bf970a313 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -70,7 +70,7 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
- "onnx('mnist_softmax.onnx')",
+ "onnx_vespa('mnist_softmax.onnx')",
"constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }",
null);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
@@ -87,7 +87,7 @@ public class RankingExpressionWithOnnxTestCase {
queryProfile,
queryProfileType);
RankProfileSearchFixture search = fixtureWith("query(mytensor)",
- "onnx('mnist_softmax.onnx')",
+ "onnx_vespa('mnist_softmax.onnx')",
null,
null,
"Placeholder",
@@ -99,7 +99,7 @@ public class RankingExpressionWithOnnxTestCase {
public void testOnnxReferenceWithDocumentFeature() {
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("attribute(mytensor)",
- "onnx('mnist_softmax.onnx')",
+ "onnx_vespa('mnist_softmax.onnx')",
null,
"field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
@@ -117,7 +117,7 @@ public class RankingExpressionWithOnnxTestCase {
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, queryProfileType);
RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)",
- "onnx('mnist_softmax.onnx')",
+ "onnx_vespa('mnist_softmax.onnx')",
"constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }",
"field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
@@ -129,21 +129,21 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testNestedOnnxReference() {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "5 + sum(onnx('mnist_softmax.onnx'))");
+ "5 + sum(onnx_vespa('mnist_softmax.onnx'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
}
@Test
public void testOnnxReferenceWithSpecifiedOutput() {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx', 'layer_add')");
+ "onnx_vespa('mnist_softmax.onnx', 'layer_add')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@Test
public void testOnnxReferenceWithSpecifiedOutputAndSignature() {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx', 'default.layer_add')");
+ "onnx_vespa('mnist_softmax.onnx', 'default.layer_add')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@@ -155,7 +155,7 @@ public class RankingExpressionWithOnnxTestCase {
new QueryProfileRegistry(),
" rank-profile my_profile {\n" +
" first-phase {\n" +
- " expression: onnx('mnist_softmax.onnx')" +
+ " expression: onnx_vespa('mnist_softmax.onnx')" +
" }\n" +
" }");
search.compileRankProfile("my_profile", applicationDir.append("models"));
@@ -164,7 +164,7 @@ public class RankingExpressionWithOnnxTestCase {
}
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
- "onnx('mnist_softmax.onnx'): " +
+ "onnx_vespa('mnist_softmax.onnx'): " +
"Model refers input 'Placeholder' of type tensor<float>(d0[1],d1[784]) but this function is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
@@ -175,13 +175,13 @@ public class RankingExpressionWithOnnxTestCase {
public void testOnnxReferenceWithWrongFunctionType() {
try {
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d5[10])(0.0)",
- "onnx('mnist_softmax.onnx')");
+ "onnx_vespa('mnist_softmax.onnx')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
}
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
- "onnx('mnist_softmax.onnx'): " +
+ "onnx_vespa('mnist_softmax.onnx'): " +
"Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[1],d1[784]), " +
"but this function returns tensor(d0[1],d5[10])",
Exceptions.toMessageString(expected));
@@ -192,13 +192,13 @@ public class RankingExpressionWithOnnxTestCase {
public void testOnnxReferenceSpecifyingNonExistingOutput() {
try {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx', 'y')");
+ "onnx_vespa('mnist_softmax.onnx', 'y')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
}
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
- "onnx('mnist_softmax.onnx','y'): " +
+ "onnx_vespa('mnist_softmax.onnx','y'): " +
"No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add",
Exceptions.toMessageString(expected));
}
@@ -207,7 +207,7 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testImportingFromStoredExpressions() throws IOException {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx')");
+ "onnx_vespa('mnist_softmax.onnx')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
// At this point the expression is stored - copy application to another location which do not have a models dir
@@ -218,7 +218,7 @@ public class RankingExpressionWithOnnxTestCase {
storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
RankProfileSearchFixture searchFromStored = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx')",
+ "onnx_vespa('mnist_softmax.onnx')",
null,
null,
"Placeholder",
@@ -243,7 +243,7 @@ public class RankingExpressionWithOnnxTestCase {
" expression: tensor<float>(d1[10],d2[784])(0.0)\n" +
" }\n" +
" first-phase {\n" +
- " expression: onnx('mnist_softmax.onnx')" +
+ " expression: onnx_vespa('mnist_softmax.onnx')" +
" }\n" +
" }" +
" rank-profile my_profile_child inherits my_profile {\n" +
@@ -288,7 +288,7 @@ public class RankingExpressionWithOnnxTestCase {
" expression: tensor<float>(d0[3])(0.0)\n" +
" }\n" +
" first-phase {\n" +
- " expression: onnx('" + name + ".onnx')" +
+ " expression: onnx_vespa('" + name + ".onnx')" +
" }\n" +
" }";
final String functionName = "imported_ml_function_" + name + "_exp_output";
@@ -310,7 +310,7 @@ public class RankingExpressionWithOnnxTestCase {
" expression: tensor<float>(d0[3])(0.0)\n" +
" }\n" +
" first-phase {\n" +
- " expression: onnx('" + name + ".onnx')" +
+ " expression: onnx_vespa('" + name + ".onnx')" +
" }\n" +
" }" +
" rank-profile my_profile_child inherits my_profile {\n" +
diff --git a/configdefinitions/src/vespa/stor-filestor.def b/configdefinitions/src/vespa/stor-filestor.def
index bf1b4294b5b..1cec77832a7 100644
--- a/configdefinitions/src/vespa/stor-filestor.def
+++ b/configdefinitions/src/vespa/stor-filestor.def
@@ -63,3 +63,11 @@ enable_merge_local_node_choose_docs_optimalization bool default=true restart
## if splitting is expensive, but listing document identifiers is fairly cheap.
## This is true for memfile persistence layer, but not for vespa search.
enable_multibit_split_optimalization bool default=true restart
+
+## Whether or not to use async message handling when scheduling storage messages from FileStorManager.
+##
+## When turned on, the calling thread (e.g. FNET network thread when using Storage API RPC)
+## gets the next async message to handle (if any) as part of scheduling a storage message.
+## This async message is then handled by the calling thread immediately,
+## instead of going via a persistence thread.
+use_async_message_handling_on_schedule bool default=false restart
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java
index a522e26a46d..7b4d82a9f53 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java
@@ -18,6 +18,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.organization.IssueHandl
import com.yahoo.vespa.hosted.controller.api.integration.organization.Mailer;
import com.yahoo.vespa.hosted.controller.api.integration.organization.OwnershipIssues;
import com.yahoo.vespa.hosted.controller.api.integration.organization.SystemMonitor;
+import com.yahoo.vespa.hosted.controller.api.integration.repair.HostRepairClient;
import com.yahoo.vespa.hosted.controller.api.integration.resource.CostReportConsumer;
import com.yahoo.vespa.hosted.controller.api.integration.resource.MeteringClient;
import com.yahoo.vespa.hosted.controller.api.integration.routing.GlobalRoutingService;
@@ -79,4 +80,5 @@ public interface ServiceRegistry {
BillingController billingController();
+ HostRepairClient hostRepairClient();
}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/ZmsClientMock.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/ZmsClientMock.java
index ec5d62569f6..942f0f35f58 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/ZmsClientMock.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/ZmsClientMock.java
@@ -136,6 +136,17 @@ public class ZmsClientMock implements ZmsClient {
}
@Override
+ public void addPolicyRule(AthenzDomain athenzDomain, String athenzPolicy, String action, AthenzResourceName resourceName, AthenzRole athenzRole) {
+
+ }
+
+ @Override
+ public boolean deletePolicyRule(AthenzDomain athenzDomain, String athenzPolicy, String action, AthenzResourceName resourceName, AthenzRole athenzRole) {
+ return false;
+ }
+
+
+ @Override
public void close() {}
private static AthenzDomain getTenantDomain(AthenzResourceName resource) {
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/Node.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/Node.java
index 07e411cd5cd..b57b2dbc496 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/Node.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/Node.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.controller.api.integration.configserver;
+import com.fasterxml.jackson.databind.JsonNode;
import com.yahoo.component.Version;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.DockerImage;
@@ -10,6 +11,8 @@ import com.yahoo.config.provision.NodeType;
import com.yahoo.config.provision.TenantName;
import java.time.Instant;
+import java.util.HashMap;
+import java.util.Map;
import java.util.Objects;
import java.util.Optional;
@@ -48,13 +51,14 @@ public class Node {
private final boolean wantToRetire;
private final boolean wantToDeprovision;
private final Optional<TenantName> reservedTo;
+ private final Map<String, JsonNode> reports;
public Node(HostName hostname, Optional<HostName> parentHostname, State state, NodeType type, NodeResources resources, Optional<ApplicationId> owner,
Version currentVersion, Version wantedVersion, Version currentOsVersion, Version wantedOsVersion,
Optional<Instant> currentFirmwareCheck, Optional<Instant> wantedFirmwareCheck, ServiceState serviceState,
Optional<Instant> suspendedSince, long restartGeneration, long wantedRestartGeneration, long rebootGeneration, long wantedRebootGeneration,
int cost, String flavor, String clusterId, ClusterType clusterType, boolean wantToRetire, boolean wantToDeprovision,
- Optional<TenantName> reservedTo, DockerImage wantedDockerImage, DockerImage currentDockerImage) {
+ Optional<TenantName> reservedTo, DockerImage wantedDockerImage, DockerImage currentDockerImage, Map<String, JsonNode> reports) {
this.hostname = hostname;
this.parentHostname = parentHostname;
this.state = state;
@@ -82,6 +86,7 @@ public class Node {
this.reservedTo = reservedTo;
this.wantedDockerImage = wantedDockerImage;
this.currentDockerImage = currentDockerImage;
+ this.reports = reports;
}
public HostName hostname() {
@@ -188,6 +193,10 @@ public class Node {
public Optional<TenantName> reservedTo() { return reservedTo; }
+ public Map<String, JsonNode> reports() {
+ return reports;
+ }
+
@Override
public boolean equals(Object o) {
if (this == o) return true;
@@ -258,6 +267,7 @@ public class Node {
private boolean wantToRetire;
private boolean wantToDeprovision;
private Optional<TenantName> reservedTo = Optional.empty();
+ private Map<String, JsonNode> reports = new HashMap<>();
public Builder() { }
@@ -289,6 +299,7 @@ public class Node {
this.wantToRetire = node.wantToRetire;
this.wantToDeprovision = node.wantToDeprovision;
this.reservedTo = node.reservedTo;
+ this.reports = node.reports;
}
public Builder hostname(HostName hostname) {
@@ -431,7 +442,7 @@ public class Node {
currentOsVersion, wantedOsVersion, currentFirmwareCheck, wantedFirmwareCheck, serviceState,
suspendedSince, restartGeneration, wantedRestartGeneration, rebootGeneration, wantedRebootGeneration,
cost, flavor, clusterId, clusterType, wantToRetire, wantToDeprovision, reservedTo,
- wantedDockerImage, currentDockerImage);
+ wantedDockerImage, currentDockerImage, reports);
}
}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/NodeRepository.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/NodeRepository.java
index aebfab7cbff..6f4b39ac9b9 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/NodeRepository.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/NodeRepository.java
@@ -90,6 +90,8 @@ public interface NodeRepository {
void retireAndDeprovision(ZoneId zoneId, String hostName);
+ void patchNode(ZoneId zoneId, String hostName, NodeRepositoryNode node);
+
private static Node toNode(NodeRepositoryNode node) {
var application = Optional.ofNullable(node.getOwner())
.map(owner -> ApplicationId.from(owner.getTenant(), owner.getApplication(),
@@ -128,7 +130,8 @@ public interface NodeRepository {
node.getWantToDeprovision(),
Optional.ofNullable(node.getReservedTo()).map(TenantName::from),
dockerImageFrom(node.getWantedDockerImage()),
- dockerImageFrom(node.getCurrentDockerImage()));
+ dockerImageFrom(node.getCurrentDockerImage()),
+ node.getReports());
}
private static String clusterIdOf(NodeMembership nodeMembership) {
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/HostRepairClient.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/HostRepairClient.java
new file mode 100644
index 00000000000..a4a5a773cb9
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/HostRepairClient.java
@@ -0,0 +1,23 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.controller.api.integration.repair;
+
+import com.yahoo.config.provision.HostName;
+import com.yahoo.config.provision.zone.ZoneApi;
+import com.yahoo.config.provision.zone.ZoneId;
+import com.yahoo.vespa.hosted.controller.api.integration.configserver.Node;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * @author olaa
+ */
+public interface HostRepairClient {
+
+ /* Checks current ticket status and takes appropriate action */
+ void updateRepairStatus(ZoneApi zone, Map<Node, RepairTicketReport> nodes);
+
+ /* Creates a repair ticket for the given host. Returns the ticket number */
+ String createTicket(HostName hostname, String colo, ZoneId zoneId, String description, String category);
+
+}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/MockRepairClient.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/MockRepairClient.java
new file mode 100644
index 00000000000..6ceceda5712
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/MockRepairClient.java
@@ -0,0 +1,33 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.controller.api.integration.repair;
+
+import com.yahoo.config.provision.HostName;
+import com.yahoo.config.provision.zone.ZoneApi;
+import com.yahoo.config.provision.zone.ZoneId;
+import com.yahoo.vespa.hosted.controller.api.integration.configserver.Node;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * @author olaa
+ */
+public class MockRepairClient implements HostRepairClient {
+
+ List<Node> updatedNodes = new ArrayList<>();
+
+ @Override
+ public void updateRepairStatus(ZoneApi zone, Map<Node, RepairTicketReport> nodes) {
+ updatedNodes.addAll(nodes.keySet());
+ }
+
+ @Override
+ public String createTicket(HostName hostname, String colo, ZoneId zoneId, String description, String category) {
+ throw new UnsupportedOperationException("Not implemented");
+ }
+
+ public List<Node> getUpdatedNodes() {
+ return updatedNodes;
+ }
+}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/RepairTicketReport.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/RepairTicketReport.java
new file mode 100644
index 00000000000..c2425fe0f72
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/RepairTicketReport.java
@@ -0,0 +1,63 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.controller.api.integration.repair;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import static com.yahoo.yolean.Exceptions.uncheck;
+
+/**
+ * @author olaa
+ */
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class RepairTicketReport {
+
+ private static final String REPORT_ID = "repairTicket";
+ private static final ObjectMapper objectMapper = new ObjectMapper();
+
+ public String status;
+ public String ticketNumber;
+ public long createdMillis;
+ public long updatedMillis;
+
+ public RepairTicketReport(@JsonProperty("status") String status,
+ @JsonProperty("ticketNumber") String ticketNumber,
+ @JsonProperty("createdMillis") long createdMillis,
+ @JsonProperty("updatedMillis") long updatedMillis) {
+ this.status = status;
+ this.ticketNumber = ticketNumber;
+ this.createdMillis = createdMillis;
+ this.updatedMillis = updatedMillis;
+ }
+
+ public String getStatus() {
+ return status;
+ }
+
+ public String getTicketNumber() {
+ return ticketNumber;
+ }
+
+ public long getCreatedMillis() {
+ return createdMillis;
+ }
+
+ public long getUpdatedMillis() {
+ return updatedMillis;
+ }
+
+ public static String getReportId() {
+ return REPORT_ID;
+ }
+
+ public static RepairTicketReport fromJsonNode(JsonNode node) {
+ return uncheck(() -> objectMapper.treeToValue(node, RepairTicketReport.class));
+ }
+
+ public JsonNode toJsonNode() {
+ return uncheck(() -> objectMapper.valueToTree(this));
+ }
+}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/package-info.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/package-info.java
new file mode 100644
index 00000000000..f53cb1ee43c
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/repair/package-info.java
@@ -0,0 +1,5 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+package com.yahoo.vespa.hosted.controller.api.integration.repair;
+
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java
index 5970494d471..a09dc0589ed 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java
@@ -34,6 +34,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.aws.ApplicationRoles;
import com.yahoo.vespa.hosted.controller.api.integration.billing.BillingController;
import com.yahoo.vespa.hosted.controller.api.integration.billing.Quota;
import com.yahoo.vespa.hosted.controller.api.integration.certificates.EndpointCertificateMetadata;
+import com.yahoo.vespa.hosted.controller.api.integration.configserver.Cluster;
import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServer;
import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServerException;
import com.yahoo.vespa.hosted.controller.api.integration.configserver.ContainerEndpoint;
@@ -353,15 +354,27 @@ public class ApplicationController {
// Carry out deployment without holding the application lock.
ActivateResult result = deploy(job.application(), applicationPackage, zone, platform, endpoints, endpointCertificateMetadata, applicationRoles);
+ // Record the quota usage for this application
+ var quotaUsage = deploymentQuotaUsage(zone, job.application());
+
lockApplicationOrThrow(applicationId, application ->
store(application.with(job.application().instance(),
instance -> instance.withNewDeployment(zone, revision, platform,
clock.instant(), warningsFrom(result),
- QuotaUsage.create(result.quotaUsageRate())))));
+ quotaUsage))));
return result;
}
}
+ private QuotaUsage deploymentQuotaUsage(ZoneId zoneId, ApplicationId applicationId) {
+ var quotaUsage = configServer.nodeRepository().getApplication(zoneId, applicationId)
+ .clusters().values().stream()
+ .map(Cluster::max)
+ .mapToDouble(max -> max.nodes() * max.nodeResources().cost())
+ .sum();
+ return QuotaUsage.create(quotaUsage);
+ }
+
private ApplicationPackage getApplicationPackage(ApplicationId application, ZoneId zone, ApplicationVersion revision) {
return new ApplicationPackage(revision.isUnknown() ? applicationStore.getDev(application, zone)
: applicationStore.get(application.tenant(), application.application(), revision));
@@ -429,11 +442,14 @@ public class ApplicationController {
ActivateResult result = deploy(instanceId, applicationPackage, zone, platformVersion,
endpoints, endpointCertificateMetadata, Optional.empty());
+ // Record the quota usage for this application
+ var quotaUsage = deploymentQuotaUsage(zone, instanceId);
+
lockApplicationOrThrow(applicationId, application ->
store(application.with(instanceId.instance(),
instance -> instance.withNewDeployment(zone, applicationVersion, platformVersion,
clock.instant(), warningsFrom(result),
- QuotaUsage.create(result.quotaUsageRate())))));
+ quotaUsage))));
return result;
}
}
@@ -547,10 +563,8 @@ public class ApplicationController {
endpoints, endpointCertificateMetadata, dockerImageRepo, domain,
applicationRoles, quota));
- var quotaUsage = configServer.getQuotaUsage(new DeploymentId(application, zone));
-
return new ActivateResult(new RevisionId(applicationPackage.hash()), preparedApplication.prepareResponse(),
- applicationPackage.zippedContent().length, quotaUsage.rate);
+ applicationPackage.zippedContent().length);
} finally {
// Even if prepare fails, a load balancer may have been provisioned. Always refresh routing policies so that
// any DNS updates can be propagated as early as possible.
@@ -567,7 +581,7 @@ public class ApplicationController {
PrepareResponse prepareResponse = new PrepareResponse();
prepareResponse.log = List.of(logEntry);
prepareResponse.configChangeActions = new ConfigChangeActions(List.of(), List.of());
- return new ActivateResult(new RevisionId("0"), prepareResponse, 0, 0.0);
+ return new ActivateResult(new RevisionId("0"), prepareResponse, 0);
}
private LockedApplication withoutDeletedDeployments(LockedApplication application, InstanceName instance) {
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/api/ActivateResult.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/api/ActivateResult.java
index e6c9e52ff69..5379a08afc0 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/api/ActivateResult.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/api/ActivateResult.java
@@ -13,13 +13,11 @@ public class ActivateResult {
private final RevisionId revisionId;
private final PrepareResponse prepareResponse;
private final long applicationZipSizeBytes;
- private final double quotaUsageRate;
- public ActivateResult(RevisionId revisionId, PrepareResponse prepareResponse, long applicationZipSizeBytes, double quotaUsageRate) {
+ public ActivateResult(RevisionId revisionId, PrepareResponse prepareResponse, long applicationZipSizeBytes) {
this.revisionId = revisionId;
this.prepareResponse = prepareResponse;
this.applicationZipSizeBytes = applicationZipSizeBytes;
- this.quotaUsageRate = quotaUsageRate;
}
public long applicationZipSizeBytes() {
@@ -33,9 +31,4 @@ public class ActivateResult {
public PrepareResponse prepareResponse() {
return prepareResponse;
}
-
- public double quotaUsageRate() {
- return quotaUsageRate;
- }
-
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java
index 124b913eb01..b4904ca3cf8 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java
@@ -646,12 +646,10 @@ public class DeploymentStatus {
@Override
public Optional<Instant> completedAt(Change change, Optional<JobId> dependent) {
return RunList.from(job)
- .matching(run -> change.platform().map(run.versions().targetPlatform()::equals).orElse(true))
- .matching(run -> change.application().map(run.versions().targetApplication()::equals).orElse(true))
- .matching(run -> dependent.flatMap(status::deploymentFor)
- .map(deployment -> Versions.from(change, deployment))
- .map(run.versions()::targetsMatch)
- .orElse(true))
+ .matching(run -> run.versions().targetsMatch(Versions.from(change,
+ status.application,
+ dependent.flatMap(status::deploymentFor),
+ status.systemVersion)))
.status(RunStatus.success)
.asList().stream()
.map(run -> run.end().get())
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java
index 0e72a1b42a7..6731c30ecd7 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java
@@ -45,6 +45,8 @@ public class ControllerMaintenance extends AbstractComponent {
private final ResourceTagMaintainer resourceTagMaintainer;
private final SystemRoutingPolicyMaintainer systemRoutingPolicyMaintainer;
private final ApplicationMetaDataGarbageCollector applicationMetaDataGarbageCollector;
+ private final HostRepairMaintainer hostRepairMaintainer;
+
@Inject
@SuppressWarnings("unused") // instantiated by Dependency Injection
@@ -75,6 +77,7 @@ public class ControllerMaintenance extends AbstractComponent {
resourceTagMaintainer = new ResourceTagMaintainer(controller, Duration.ofMinutes(30), controller.serviceRegistry().resourceTagger());
systemRoutingPolicyMaintainer = new SystemRoutingPolicyMaintainer(controller, Duration.ofMinutes(10));
applicationMetaDataGarbageCollector = new ApplicationMetaDataGarbageCollector(controller, Duration.ofHours(12));
+ hostRepairMaintainer = new HostRepairMaintainer(controller, Duration.ofHours(12));
}
public Upgrader upgrader() { return upgrader; }
@@ -102,6 +105,7 @@ public class ControllerMaintenance extends AbstractComponent {
rotationStatusUpdater.close();
resourceTagMaintainer.close();
systemRoutingPolicyMaintainer.close();
+ hostRepairMaintainer.close();
}
/** Create one OS upgrader per cloud found in the zone registry of controller */
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentExpirer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentExpirer.java
index 7bd2c737fcb..37de7369452 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentExpirer.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentExpirer.java
@@ -25,10 +25,10 @@ public class DeploymentExpirer extends ControllerMaintainer {
@Override
protected boolean maintain() {
boolean success = true;
- for (Application application : controller().applications().readable())
+ for (Application application : controller().applications().readable()) {
for (Instance instance : application.instances().values())
for (Deployment deployment : instance.deployments().values()) {
- if ( ! isExpired(deployment)) continue;
+ if (!isExpired(deployment)) continue;
try {
log.log(Level.INFO, "Expiring deployment of " + instance.id() + " in " + deployment.zone());
@@ -40,6 +40,7 @@ public class DeploymentExpirer extends ControllerMaintainer {
interval());
}
}
+ }
return success;
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/HostRepairMaintainer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/HostRepairMaintainer.java
new file mode 100644
index 00000000000..e3c6862384f
--- /dev/null
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/HostRepairMaintainer.java
@@ -0,0 +1,81 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.controller.maintenance;
+
+import com.yahoo.config.provision.CloudName;
+import com.yahoo.config.provision.SystemName;
+import com.yahoo.config.provision.zone.ZoneApi;
+import com.yahoo.vespa.hosted.controller.Controller;
+import com.yahoo.vespa.hosted.controller.api.integration.configserver.Node;
+import com.yahoo.vespa.hosted.controller.api.integration.configserver.NodeRepository;
+import com.yahoo.vespa.hosted.controller.api.integration.repair.RepairTicketReport;
+import com.yahoo.vespa.hosted.controller.api.integration.repair.HostRepairClient;
+import com.yahoo.yolean.Exceptions;
+
+import java.time.Duration;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Predicate;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
+
+import static com.yahoo.yolean.Exceptions.uncheck;
+
+/**
+ *
+ * Responsible for keeping track of hosts under repair.
+ *
+ * @author olaa
+ */
+public class HostRepairMaintainer extends ControllerMaintainer {
+
+ private final NodeRepository nodeRepository;
+ private final HostRepairClient repairClient;
+
+ private static final Logger log = Logger.getLogger(HostRepairMaintainer.class.getName());
+
+
+ public HostRepairMaintainer(Controller controller, Duration interval) {
+ super(controller, interval, null, SystemName.allOf(Predicate.not(SystemName::isPublic)));
+ this.nodeRepository = controller.serviceRegistry().configServer().nodeRepository();
+ this.repairClient = controller.serviceRegistry().hostRepairClient();
+ }
+
+
+ @Override
+ protected boolean maintain() {
+ AtomicInteger exceptions = new AtomicInteger(0);
+
+ controller().zoneRegistry().zones()
+ .reachable().zones().stream()
+ .forEach(zoneApi -> {
+ var nodeTicketMap = nodeRepository.list((zoneApi).getId())
+ .stream()
+ .filter(this::hasOpenTicket)
+ .collect(Collectors.toMap(
+ node -> node,
+ this::getTicketReport)
+ );
+ try {
+ repairClient.updateRepairStatus(zoneApi, nodeTicketMap);
+ } catch (Exception e) {
+ log.warning("Failed to update repair status; " + Exceptions.toMessageString(e));
+ exceptions.incrementAndGet();
+ }
+ }
+ );
+
+ return exceptions.get() == 0;
+ }
+
+
+ private boolean hasOpenTicket(Node node) {
+ var reports = node.reports();
+ if (!reports.containsKey(RepairTicketReport.getReportId())) {
+ return false;
+ }
+ return "OPEN".equals(getTicketReport(node).getStatus());
+ }
+
+ private RepairTicketReport getTicketReport(Node node) {
+ return uncheck(() -> RepairTicketReport.fromJsonNode(node.reports().get(RepairTicketReport.getReportId())));
+ }
+}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java
index f1306b51b39..6e5a9ddc7ab 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java
@@ -24,6 +24,7 @@ import java.time.Instant;
import java.util.Collection;
import java.util.EnumSet;
import java.util.List;
+import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.stream.Collectors;
@@ -1216,4 +1217,42 @@ public class DeploymentTriggerTest {
app.assertNotRunning(stagingTest);
}
+ @Test
+ public void testTriggeringOfIdleTestJobsWhenFirstDeploymentIsOnNewerVersionThanChange() {
+ ApplicationPackage applicationPackage = new ApplicationPackageBuilder().systemTest()
+ .stagingTest()
+ .region("us-east-3")
+ .region("us-west-1")
+ .build();
+ var app = tester.newDeploymentContext().submit(applicationPackage).deploy();
+ var appToAvoidVersionGC = tester.newDeploymentContext("g", "c", "default").submit().deploy();
+
+ Version version2 = new Version("7.8.9");
+ Version version3 = new Version("8.9.10");
+ tester.controllerTester().upgradeSystem(version2);
+ tester.deploymentTrigger().triggerChange(appToAvoidVersionGC.instanceId(), Change.of(version2));
+ appToAvoidVersionGC.deployPlatform(version2);
+
+ // app upgrades first zone to version3, and then the other two to version2.
+ tester.controllerTester().upgradeSystem(version3);
+ tester.deploymentTrigger().triggerChange(app.instanceId(), Change.of(version3));
+ app.runJob(systemTest).runJob(stagingTest);
+ tester.triggerJobs();
+ tester.upgrader().overrideConfidence(version3, VespaVersion.Confidence.broken);
+ tester.controllerTester().computeVersionStatus();
+ tester.upgrader().run();
+ assertEquals(Optional.of(version2), app.instance().change().platform());
+
+ app.runJob(systemTest)
+ .runJob(productionUsEast3)
+ .runJob(stagingTest)
+ .runJob(productionUsWest1);
+
+ assertEquals(version3, app.instanceJobs().get(productionUsEast3).lastSuccess().get().versions().targetPlatform());
+ assertEquals(version2, app.instanceJobs().get(productionUsWest1).lastSuccess().get().versions().targetPlatform());
+ assertEquals(Map.of(), app.deploymentStatus().jobsToRun());
+ assertEquals(Change.empty(), app.instance().change());
+ assertEquals(List.of(), tester.jobs().active());
+ }
+
}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/QuotaUsageTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/QuotaUsageTest.java
new file mode 100644
index 00000000000..d2901aeac97
--- /dev/null
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/QuotaUsageTest.java
@@ -0,0 +1,21 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.controller.deployment;
+
+import com.yahoo.config.provision.zone.ZoneId;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author ogronnesby
+ */
+public class QuotaUsageTest {
+
+ @Test
+ public void testQuotaUsageIsPersisted() {
+ var tester = new DeploymentTester();
+ var context = tester.newDeploymentContext().submit().deploy();
+ assertEquals(1.062, context.deployment(ZoneId.from("prod.us-west-1")).quota().rate(), 0.01);
+ }
+
+}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/NodeRepositoryMock.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/NodeRepositoryMock.java
index 90276b6b590..72cc000ef98 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/NodeRepositoryMock.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/NodeRepositoryMock.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.controller.integration;
+import com.fasterxml.jackson.databind.JsonNode;
import com.yahoo.component.Version;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.HostName;
@@ -226,6 +227,11 @@ public class NodeRepositoryMock implements NodeRepository {
nodeRepository.get(zoneId).remove(HostName.from(hostName));
}
+ @Override
+ public void patchNode(ZoneId zoneId, String hostName, NodeRepositoryNode node) {
+ throw new UnsupportedOperationException();
+ }
+
public Optional<Duration> osUpgradeBudget(ZoneId zone, NodeType type, Version version) {
return Optional.ofNullable(osUpgradeBudgets.get(Objects.hash(zone, type, version)));
}
@@ -264,4 +270,8 @@ public class NodeRepositoryMock implements NodeRepository {
modifyNodes(deployment, hostname, node -> new Node.Builder(node).rebootGeneration(node.rebootGeneration() + 1).build());
}
+ public void addReport(ZoneId zoneId, HostName hostName, String reportId, JsonNode report) {
+ nodeRepository.get(zoneId).get(hostName).reports().put(reportId, report);
+ }
+
}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java
index 1b21f7db7c4..3ec02c6ceb7 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java
@@ -20,6 +20,8 @@ import com.yahoo.vespa.hosted.controller.api.integration.dns.MemoryNameService;
import com.yahoo.vespa.hosted.controller.api.integration.entity.MemoryEntityService;
import com.yahoo.vespa.hosted.controller.api.integration.organization.MockContactRetriever;
import com.yahoo.vespa.hosted.controller.api.integration.organization.MockIssueHandler;
+import com.yahoo.vespa.hosted.controller.api.integration.repair.MockRepairClient;
+import com.yahoo.vespa.hosted.controller.api.integration.repair.HostRepairClient;
import com.yahoo.vespa.hosted.controller.api.integration.resource.CostReportConsumerMock;
import com.yahoo.vespa.hosted.controller.api.integration.routing.GlobalRoutingService;
import com.yahoo.vespa.hosted.controller.api.integration.routing.MemoryGlobalRoutingService;
@@ -61,6 +63,7 @@ public class ServiceRegistryMock extends AbstractComponent implements ServiceReg
private final MockResourceTagger mockResourceTagger = new MockResourceTagger();
private final ApplicationRoleService applicationRoleService = new NoopApplicationRoleService();
private final BillingController billingController = new MockBillingController();
+ private final MockRepairClient repairClient = new MockRepairClient();
public ServiceRegistryMock(SystemName system) {
this.zoneRegistryMock = new ZoneRegistryMock(system);
@@ -192,6 +195,11 @@ public class ServiceRegistryMock extends AbstractComponent implements ServiceReg
return billingController;
}
+ @Override
+ public MockRepairClient hostRepairClient() {
+ return repairClient;
+ }
+
public ConfigServerMock configServerMock() {
return configServerMock;
}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/HostRepairMaintainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/HostRepairMaintainerTest.java
new file mode 100644
index 00000000000..556755581fe
--- /dev/null
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/HostRepairMaintainerTest.java
@@ -0,0 +1,51 @@
+package com.yahoo.vespa.hosted.controller.maintenance;
+
+import com.yahoo.config.provision.HostName;
+import com.yahoo.config.provision.zone.ZoneId;
+import com.yahoo.vespa.hosted.controller.ControllerTester;
+import com.yahoo.vespa.hosted.controller.api.integration.noderepository.NodeRepositoryNode;
+import com.yahoo.vespa.hosted.controller.api.integration.repair.HostRepairClient;
+import com.yahoo.vespa.hosted.controller.api.integration.repair.MockRepairClient;
+import com.yahoo.vespa.hosted.controller.api.integration.repair.RepairTicketReport;
+import org.junit.Test;
+
+import java.time.Duration;
+import java.time.Instant;
+import java.util.List;
+
+import static org.junit.Assert.*;
+
+/**
+ * @author olaa
+ */
+public class HostRepairMaintainerTest {
+
+ private final ControllerTester tester = new ControllerTester();
+ private final HostRepairMaintainer maintainer = new HostRepairMaintainer(tester.controller(), Duration.ofHours(12));
+
+ @Test
+ public void maintain() {
+ var zoneId = ZoneId.from("dev.us-east-1");
+ var hostname1 = HostName.from("node-1-tenant-host-dev.us-east-1");
+ var hostname2 = HostName.from("node-2-tenant-host-dev.us-east-1");
+ var timestamp = Instant.now().toEpochMilli();
+ var openTicket = new RepairTicketReport("OPEN", "ticket-1", timestamp, timestamp);
+ var closedTicket = new RepairTicketReport("CLOSED", "ticket-2", timestamp, timestamp);
+
+ tester.configServer().nodeRepository().addReport(
+ zoneId,
+ hostname1,
+ RepairTicketReport.getReportId(),
+ openTicket.toJsonNode());
+ tester.configServer().nodeRepository().addReport(
+ zoneId,
+ hostname2,
+ RepairTicketReport.getReportId(),
+ closedTicket.toJsonNode());
+
+ maintainer.maintain();
+ var updatedNodes = tester.serviceRegistry().hostRepairClient().getUpdatedNodes();
+ assertEquals(1, updatedNodes.size());
+ assertEquals(hostname1, updatedNodes.get(0).hostname());
+ }
+} \ No newline at end of file
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json
index 385f0fbc3cf..bb3578b2482 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json
@@ -28,6 +28,9 @@
"name": "DeploymentMetricsMaintainer"
},
{
+ "name": "HostRepairMaintainer"
+ },
+ {
"name": "JobRunner"
},
{
diff --git a/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp b/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp
index 0299dc3ebba..7182d66f8aa 100644
--- a/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp
+++ b/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp
@@ -36,7 +36,6 @@
#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/tensor/default_tensor_engine.h>
#include <vespa/eval/tensor/default_value_builder_factory.h>
-#include <vespa/eval/tensor/mixed/packed_mixed_tensor_builder_factory.h>
#include <vespa/vespalib/util/benchmark_timer.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/objects/nbostream.h>
@@ -230,7 +229,6 @@ Impl default_tensor_engine_impl(1, "DefaultTensorEngine", "OLD PROD", DefaultTe
Impl simple_value_impl(3, " SimpleValue", " SimpleV", SimpleValueBuilderFactory::get(), false);
Impl fast_value_impl(0, " FastValue", "NEW PROD", FastValueBuilderFactory::get(), false);
Impl optimized_fast_value_impl(2, "Optimized FastValue", "Optimize", FastValueBuilderFactory::get(), true);
-Impl packed_mixed_tensor_impl(5, " PackedMixedTensor", " Packed", PackedMixedTensorBuilderFactory::get(), false);
Impl default_tensor_value_impl(4, " DefaultValue", "DefaultV", DefaultValueBuilderFactory::get(), false);
vespalib::string short_header("--------");
@@ -243,7 +241,6 @@ std::vector<CREF<Impl>> impl_list = {default_tensor_engine_impl,
simple_value_impl,
fast_value_impl,
optimized_fast_value_impl,
- packed_mixed_tensor_impl,
default_tensor_value_impl};
//-----------------------------------------------------------------------------
@@ -982,6 +979,14 @@ void print_summary() {
}
int main(int argc, char **argv) {
+ const std::string run_only_prod_option = "--limit-implementations";
+ if ((argc > 1) && (argv[1] == run_only_prod_option )) {
+ impl_list.clear();
+ impl_list.push_back(fast_value_impl);
+ impl_list.push_back(default_tensor_engine_impl);
+ ++argv;
+ --argc;
+ }
::testing::InitGoogleTest(&argc, argv);
int result = RUN_ALL_TESTS();
print_summary();
diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml
index 4beaf6086a6..181ef6dffbd 100644
--- a/fat-model-dependencies/pom.xml
+++ b/fat-model-dependencies/pom.xml
@@ -221,5 +221,10 @@
<artifactId>jdisc_http_service</artifactId>
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>${protobuf.version}</version>
+ </dependency>
</dependencies>
</project>
diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
index c56de1bb178..dd6d84e3ad7 100644
--- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
+++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
@@ -161,7 +161,7 @@ public class Flags {
ZONE_ID, APPLICATION_ID);
public static final UnboundBooleanFlag USE_CONTENT_NODE_BTREE_DB = defineFeatureFlag(
- "use-content-node-btree-db", false,
+ "use-content-node-btree-db", true,
"Whether to use the new B-tree bucket database on the content node.",
"Takes effect at restart of content node process",
ZONE_ID, APPLICATION_ID);
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java
index c73a19bd9e2..eace7457615 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java
@@ -61,7 +61,7 @@ public class Autoscaler {
private Optional<AllocatableClusterResources> autoscale(Cluster cluster,
List<Node> clusterNodes, Limits limits, boolean exclusive) {
- if (unstable(clusterNodes)) return Optional.empty();
+ if (unstable(clusterNodes, nodeRepository)) return Optional.empty();
AllocatableClusterResources currentAllocation = new AllocatableClusterResources(clusterNodes, nodeRepository);
@@ -111,10 +111,18 @@ public class Autoscaler {
return 20;
}
- public static boolean unstable(List<Node> nodes) {
- return nodes.stream().anyMatch(node -> node.status().wantToRetire() ||
- node.allocation().get().membership().retired() ||
- node.allocation().get().isRemovable());
+ public static boolean unstable(List<Node> nodes, NodeRepository nodeRepository) {
+ // The cluster is processing recent changes
+ if (nodes.stream().anyMatch(node -> node.status().wantToRetire() ||
+ node.allocation().get().membership().retired() ||
+ node.allocation().get().isRemovable()))
+ return true;
+
+ // A deployment is ongoing
+ if (nodeRepository.getNodes(nodes.get(0).allocation().get().owner(), Node.State.reserved).size() > 0)
+ return true;
+
+ return false;
}
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsV2MetricsFetcher.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsV2MetricsFetcher.java
index b4a63175548..4597fc04e17 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsV2MetricsFetcher.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsV2MetricsFetcher.java
@@ -56,7 +56,7 @@ public class MetricsV2MetricsFetcher extends AbstractComponent implements Metric
NodeList applicationNodes = nodeRepository.list(application).state(Node.State.active);
// Do not try to draw conclusions from utilization while unstable
- if (Autoscaler.unstable(applicationNodes.asList())) return Collections.emptyList();
+ if (Autoscaler.unstable(applicationNodes.asList(), nodeRepository)) return Collections.emptyList();
Optional<Node> metricsV2Container = applicationNodes.container()
.matching(node -> expectedUp(node))
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java
index 3b01f678982..c0fd7df9b2e 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java
@@ -28,7 +28,6 @@ import java.util.stream.Collectors;
*/
public class AutoscalingMaintainer extends NodeRepositoryMaintainer {
- private final MetricsDb metricsDb;
private final Autoscaler autoscaler;
private final Deployer deployer;
private final Metric metric;
@@ -40,7 +39,6 @@ public class AutoscalingMaintainer extends NodeRepositoryMaintainer {
Duration interval) {
super(nodeRepository, interval, metric);
this.autoscaler = new Autoscaler(metricsDb, nodeRepository);
- this.metricsDb = metricsDb;
this.metric = metric;
this.deployer = deployer;
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainer.java
index c9538d878f2..9ef5a841a7a 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainer.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainer.java
@@ -2,6 +2,7 @@
package com.yahoo.vespa.hosted.provision.maintenance;
import com.yahoo.config.provision.ApplicationId;
+import com.yahoo.config.provision.ApplicationLockException;
import com.yahoo.config.provision.ClusterResources;
import com.yahoo.config.provision.ClusterSpec;
import com.yahoo.jdisc.Metric;
@@ -39,32 +40,39 @@ public class ScalingSuggestionsMaintainer extends NodeRepositoryMaintainer {
@Override
protected boolean maintain() {
- boolean success = true;
- if ( ! nodeRepository().zone().environment().isProduction()) return success;
+ if ( ! nodeRepository().zone().environment().isProduction()) return true;
- activeNodesByApplication().forEach((applicationId, nodes) -> suggest(applicationId, nodes));
- return success;
+ int successes = 0;
+ for (var application : activeNodesByApplication().entrySet())
+ successes += suggest(application.getKey(), application.getValue());
+ return successes > 0;
}
- private void suggest(ApplicationId application, List<Node> applicationNodes) {
- nodesByCluster(applicationNodes).forEach((clusterId, clusterNodes) ->
- suggest(application, clusterId, clusterNodes));
+ private int suggest(ApplicationId application, List<Node> applicationNodes) {
+ int successes = 0;
+ for (var cluster : nodesByCluster(applicationNodes).entrySet())
+ successes += suggest(application, cluster.getKey(), cluster.getValue()) ? 1 : 0;
+ return successes;
}
private Applications applications() {
return nodeRepository().applications();
}
- private void suggest(ApplicationId applicationId,
- ClusterSpec.Id clusterId,
- List<Node> clusterNodes) {
+ private boolean suggest(ApplicationId applicationId,
+ ClusterSpec.Id clusterId,
+ List<Node> clusterNodes) {
Application application = applications().get(applicationId).orElse(new Application(applicationId));
Optional<Cluster> cluster = application.cluster(clusterId);
- if (cluster.isEmpty()) return;
+ if (cluster.isEmpty()) return true;
Optional<ClusterResources> suggestion = autoscaler.suggest(cluster.get(), clusterNodes);
// Wait only a short time for the lock to avoid interfering with change deployments
try (Mutex lock = nodeRepository().lock(applicationId, Duration.ofSeconds(1))) {
applications().get(applicationId).ifPresent(a -> storeSuggestion(suggestion, clusterId, a, lock));
+ return true;
+ }
+ catch (ApplicationLockException e) {
+ return false;
}
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java
index 240963a8c0d..1e98160955c 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java
@@ -112,8 +112,11 @@ class NodeAllocation {
boolean resizeable = requestedNodes.considerRetiring() && candidate.isResizable;
boolean acceptToRetire = acceptToRetire(candidate);
- if ((! saturated() && hasCompatibleFlavor(candidate) && requestedNodes.acceptable(candidate)) || acceptToRetire)
- accepted.add(acceptNode(candidate, shouldRetire(candidate), resizeable));
+ if ((! saturated() && hasCompatibleFlavor(candidate) && requestedNodes.acceptable(candidate)) || acceptToRetire) {
+ candidate = candidate.withNode();
+ if (candidate.isValid())
+ accepted.add(acceptNode(candidate, shouldRetire(candidate), resizeable));
+ }
}
else if (! saturated() && hasCompatibleFlavor(candidate)) {
if ( ! nodeResourceLimits.isWithinRealLimits(candidate, cluster)) {
@@ -240,7 +243,6 @@ class NodeAllocation {
}
private Node acceptNode(NodeCandidate candidate, boolean wantToRetire, boolean resizeable) {
- candidate = candidate.withNode();
Node node = candidate.toNode();
if (node.allocation().isPresent()) // Record the currently requested resources
@@ -356,7 +358,7 @@ class NodeAllocation {
candidate = candidate.withNode();
Allocation allocation = candidate.allocation().get();
candidate = candidate.withNode(candidate.toNode().with(allocation.with(allocation.membership()
- .with(allocation.membership().cluster().exclusive(requestedNodes.isExclusive())))));
+ .with(allocation.membership().cluster().exclusive(requestedNodes.isExclusive())))));
nodes.put(candidate.toNode().hostname(), candidate);
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeCandidate.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeCandidate.java
index b915053fff5..02086e2bace 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeCandidate.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeCandidate.java
@@ -87,7 +87,11 @@ abstract class NodeCandidate implements Nodelike, Comparable<NodeCandidate> {
/** Returns a copy of this with exclusive switch set to given value */
public abstract NodeCandidate withExclusiveSwitch(boolean exclusiveSwitch);
- /** Returns the node instance of this candidate, or an invalid node if it cannot be created */
+ /**
+ * Returns the node instance of this candidate, allocating it if necessary.
+ *
+ * @throws IllegalStateException if the node candidate is invalid
+ */
public abstract Node toNode();
/** Returns whether this node can - as far as we know - be used to run the application workload */
@@ -358,10 +362,12 @@ abstract class NodeCandidate implements Nodelike, Comparable<NodeCandidate> {
Optional<IP.Allocation> allocation;
try {
allocation = parent.get().ipConfig().pool().findAllocation(allNodes, nodeRepository.nameResolver());
- if (allocation.isEmpty()) return new InvalidNodeCandidate(resources, freeParentCapacity, parent.get());
+ if (allocation.isEmpty()) return new InvalidNodeCandidate(resources, freeParentCapacity, parent.get(),
+ "No IP addresses available on parent host");
} catch (Exception e) {
log.warning("Failed allocating IP address on " + parent.get() +": " + Exceptions.toMessageString(e));
- return new InvalidNodeCandidate(resources, freeParentCapacity, parent.get());
+ return new InvalidNodeCandidate(resources, freeParentCapacity, parent.get(),
+ "Failed when allocating IP address on host");
}
Node node = Node.createDockerNode(allocation.get().addresses(),
@@ -409,10 +415,13 @@ abstract class NodeCandidate implements Nodelike, Comparable<NodeCandidate> {
static class InvalidNodeCandidate extends NodeCandidate {
private final NodeResources resources;
+ private final String invalidReason;
- private InvalidNodeCandidate(NodeResources resources, NodeResources freeParentCapacity, Node parent) {
+ private InvalidNodeCandidate(NodeResources resources, NodeResources freeParentCapacity, Node parent,
+ String invalidReason) {
super(freeParentCapacity, Optional.of(parent), false, false, false, true, false);
this.resources = resources;
+ this.invalidReason = invalidReason;
}
@Override
@@ -453,7 +462,7 @@ abstract class NodeCandidate implements Nodelike, Comparable<NodeCandidate> {
@Override
public Node toNode() {
- throw new IllegalStateException("Candidate node on " + parent.get() + " is invalid");
+ throw new IllegalStateException("Candidate node on " + parent.get() + " is invalid: " + invalidReason);
}
@Override
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodePrioritizer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodePrioritizer.java
index 3aa6253979d..de7adf9fa2d 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodePrioritizer.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodePrioritizer.java
@@ -87,19 +87,14 @@ public class NodePrioritizer {
/** Returns the list of nodes sorted by {@link NodeCandidate#compareTo(NodeCandidate)} */
private List<NodeCandidate> prioritize() {
- // Group candidates by their cluster switch
- Map<ClusterSwitch, List<NodeCandidate>> candidatesBySwitch = this.nodes.stream().collect(Collectors.groupingBy(candidate -> {
- Node nodeOnSwitch = candidate.parent.orElseGet(candidate::toNode);
- ClusterSpec.Id cluster = candidate.toNode().allocation()
- .map(a -> a.membership().cluster().id())
- .orElseGet(clusterSpec::id);
- return ClusterSwitch.from(cluster, nodeOnSwitch.switchHostname());
- }));
+ // Group candidates by their switch hostname
+ Map<Optional<String>, List<NodeCandidate>> candidatesBySwitch = this.nodes.stream()
+ .collect(Collectors.groupingBy(candidate -> candidate.parent.orElseGet(candidate::toNode).switchHostname()));
// Mark lower priority nodes on shared switch as non-exclusive
List<NodeCandidate> nodes = new ArrayList<>(this.nodes.size());
for (var clusterSwitch : candidatesBySwitch.keySet()) {
List<NodeCandidate> switchCandidates = candidatesBySwitch.get(clusterSwitch);
- if (clusterSwitch.equals(ClusterSwitch.unknown)) {
+ if (clusterSwitch.isEmpty()) {
nodes.addAll(switchCandidates); // Nodes are on exclusive switch by default
} else {
Collections.sort(switchCandidates);
@@ -156,6 +151,7 @@ public class NodePrioritizer {
.filter(node -> legalStates.contains(node.state()))
.filter(node -> node.allocation().isPresent())
.filter(node -> node.allocation().get().owner().equals(application))
+ .filter(node -> node.allocation().get().membership().cluster().id().equals(clusterSpec.id()))
.filter(node -> node.state() == Node.State.active || canStillAllocateToParentOf(node))
.map(node -> candidateFrom(node, false))
.forEach(nodes::add);
@@ -206,43 +202,9 @@ public class NodePrioritizer {
*/
private boolean canStillAllocateToParentOf(Node node) {
if (node.parentHostname().isEmpty()) return true;
- Optional<Node> parent = node.parentHostname().flatMap(nodeRepository::getNode);
+ Optional<Node> parent = allNodes.parentOf(node);
if (parent.isEmpty()) return false;
return nodeRepository.canAllocateTenantNodeTo(parent.get());
}
- /** A cluster and its network switch */
- private static class ClusterSwitch {
-
- private static final ClusterSwitch unknown = new ClusterSwitch(null, null);
-
- private final ClusterSpec.Id cluster;
- private final String switchHostname;
-
- public ClusterSwitch(ClusterSpec.Id cluster, String switchHostname) {
- this.cluster = cluster;
- this.switchHostname = switchHostname;
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
- ClusterSwitch that = (ClusterSwitch) o;
- return Objects.equals(cluster, that.cluster) &&
- Objects.equals(switchHostname, that.switchHostname);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(cluster, switchHostname);
- }
-
- public static ClusterSwitch from(ClusterSpec.Id cluster, Optional<String> switchHostname) {
- if (switchHostname.isEmpty()) return unknown;
- return new ClusterSwitch(cluster, switchHostname.get());
- }
-
- }
-
}
diff --git a/searchcore/src/tests/proton/attribute/attribute_populator/attribute_populator_test.cpp b/searchcore/src/tests/proton/attribute/attribute_populator/attribute_populator_test.cpp
index 683a6cd1197..2a5444c2525 100644
--- a/searchcore/src/tests/proton/attribute/attribute_populator/attribute_populator_test.cpp
+++ b/searchcore/src/tests/proton/attribute/attribute_populator/attribute_populator_test.cpp
@@ -97,12 +97,12 @@ TEST_F("require that reprocess with document populates attribute", Fixture)
f._pop->handleExisting(5, f._ctx.create(0, 33));
EXPECT_EQUAL(6u, attr->get()->getNumDocs());
EXPECT_EQUAL(33, attr->get()->getInt(5));
- EXPECT_EQUAL(1u, attr->get()->getStatus().getLastSyncToken());
+ EXPECT_EQUAL(0u, attr->get()->getStatus().getLastSyncToken());
f._pop->handleExisting(6, f._ctx.create(1, 44));
EXPECT_EQUAL(7u, attr->get()->getNumDocs());
EXPECT_EQUAL(44, attr->get()->getInt(6));
- EXPECT_EQUAL(2u, attr->get()->getStatus().getLastSyncToken());
+ EXPECT_EQUAL(0u, attr->get()->getStatus().getLastSyncToken());
f._pop->done();
EXPECT_EQUAL(CREATE_SERIAL_NUM, attr->get()->getStatus().getLastSyncToken());
}
diff --git a/searchcore/src/tests/proton/attribute/attribute_test.cpp b/searchcore/src/tests/proton/attribute/attribute_test.cpp
index b30420ead24..c98127f4daf 100644
--- a/searchcore/src/tests/proton/attribute/attribute_test.cpp
+++ b/searchcore/src/tests/proton/attribute/attribute_test.cpp
@@ -169,29 +169,34 @@ public:
_mgr->addAttribute(attr->getName(), std::move(attr));
allocAttributeWriter();
}
- void put(SerialNum serialNum, const Document &doc, DocumentIdT lid,
- bool immediateCommit = true) {
- _aw->put(serialNum, doc, lid, immediateCommit, emptyCallback);
+ void put(SerialNum serialNum, const Document &doc, DocumentIdT lid) {
+ _aw->put(serialNum, doc, lid, emptyCallback);
+ commit(serialNum);
}
void update(SerialNum serialNum, const DocumentUpdate &upd,
- DocumentIdT lid, bool immediateCommit, IFieldUpdateCallback & onUpdate) {
- _aw->update(serialNum, upd, lid, immediateCommit, emptyCallback, onUpdate);
+ DocumentIdT lid, IFieldUpdateCallback & onUpdate) {
+ _aw->update(serialNum, upd, lid, emptyCallback, onUpdate);
+ commit(serialNum);
}
- void update(SerialNum serialNum, const Document &doc,
- DocumentIdT lid, bool immediateCommit) {
- _aw->update(serialNum, doc, lid, immediateCommit, emptyCallback);
+ void update(SerialNum serialNum, const Document &doc, DocumentIdT lid) {
+ _aw->update(serialNum, doc, lid, emptyCallback);
+ commit(serialNum);
}
- void remove(SerialNum serialNum, DocumentIdT lid, bool immediateCommit = true) {
- _aw->remove(serialNum, lid, immediateCommit, emptyCallback);
+ void remove(SerialNum serialNum, DocumentIdT lid) {
+ _aw->remove(serialNum, lid, emptyCallback);
+ commit(serialNum);
}
- void remove(const LidVector &lidVector, SerialNum serialNum, bool immediateCommit = true) {
- _aw->remove(lidVector, serialNum, immediateCommit, emptyCallback);
+ void remove(const LidVector &lidVector, SerialNum serialNum) {
+ _aw->remove(lidVector, serialNum, emptyCallback);
+ commit(serialNum);
}
void commit(SerialNum serialNum) {
_aw->forceCommit(serialNum, emptyCallback);
}
void assertExecuteHistory(std::vector<uint32_t> expExecuteHistory) {
- EXPECT_EQ(expExecuteHistory, _attributeFieldWriter->getExecuteHistory());
+ auto includeCommit = expExecuteHistory;
+ includeCommit.insert(includeCommit.end(), expExecuteHistory.begin(), expExecuteHistory.end());
+ EXPECT_EQ(includeCommit, _attributeFieldWriter->getExecuteHistory());
}
SerialNum test_force_commit(AttributeVector &attr, SerialNum serialNum) {
commit(serialNum);
@@ -400,29 +405,29 @@ TEST_F(AttributeWriterTest, visibility_delay_is_honoured)
EXPECT_EQ(2u, a1->getNumDocs());
EXPECT_EQ(3u, a1->getStatus().getLastSyncToken());
AttributeWriter awDelayed(_mgr);
- awDelayed.put(4, *doc, 2, false, emptyCallback);
+ awDelayed.put(4, *doc, 2, emptyCallback);
EXPECT_EQ(3u, a1->getNumDocs());
EXPECT_EQ(3u, a1->getStatus().getLastSyncToken());
- awDelayed.put(5, *doc, 4, false, emptyCallback);
+ awDelayed.put(5, *doc, 4, emptyCallback);
EXPECT_EQ(5u, a1->getNumDocs());
EXPECT_EQ(3u, a1->getStatus().getLastSyncToken());
awDelayed.forceCommit(6, emptyCallback);
EXPECT_EQ(6u, a1->getStatus().getLastSyncToken());
AttributeWriter awDelayedShort(_mgr);
- awDelayedShort.put(7, *doc, 2, false, emptyCallback);
+ awDelayedShort.put(7, *doc, 2, emptyCallback);
EXPECT_EQ(6u, a1->getStatus().getLastSyncToken());
- awDelayedShort.put(8, *doc, 2, false, emptyCallback);
+ awDelayedShort.put(8, *doc, 2, emptyCallback);
awDelayedShort.forceCommit(8, emptyCallback);
EXPECT_EQ(8u, a1->getStatus().getLastSyncToken());
verifyAttributeContent(*a1, 2, "10");
awDelayed.put(9, *idb.startDocument("id:ns:searchdocument::1").startAttributeField("a1").addStr("11").endField().endDocument(),
- 2, false, emptyCallback);
+ 2, emptyCallback);
awDelayed.put(10, *idb.startDocument("id:ns:searchdocument::1").startAttributeField("a1").addStr("20").endField().endDocument(),
- 2, false, emptyCallback);
+ 2, emptyCallback);
awDelayed.put(11, *idb.startDocument("id:ns:searchdocument::1").startAttributeField("a1").addStr("30").endField().endDocument(),
- 2, false, emptyCallback);
+ 2, emptyCallback);
EXPECT_EQ(8u, a1->getStatus().getLastSyncToken());
verifyAttributeContent(*a1, 2, "10");
awDelayed.forceCommit(12, emptyCallback);
@@ -472,8 +477,7 @@ TEST_F(AttributeWriterTest, handles_update)
.addUpdate(ArithmeticValueUpdate(ArithmeticValueUpdate::Add, 10)));
DummyFieldUpdateCallback onUpdate;
- bool immediateCommit = true;
- update(2, upd, 1, immediateCommit, onUpdate);
+ update(2, upd, 1, onUpdate);
attribute::IntegerContent ibuf;
ibuf.fill(*a1, 1);
@@ -483,9 +487,9 @@ TEST_F(AttributeWriterTest, handles_update)
EXPECT_EQ(1u, ibuf.size());
EXPECT_EQ(30u, ibuf[0]);
- update(2, upd, 1, immediateCommit, onUpdate); // same sync token as previous
+ update(2, upd, 1, onUpdate); // same sync token as previous
try {
- update(1, upd, 1, immediateCommit, onUpdate); // lower sync token than previous
+ update(1, upd, 1, onUpdate); // lower sync token than previous
EXPECT_TRUE(true); // update is ignored
} catch (vespalib::IllegalStateException & e) {
LOG(info, "Got expected exception: '%s'", e.getMessage().c_str());
@@ -517,9 +521,8 @@ TEST_F(AttributeWriterTest, handles_predicate_update)
PredicateIndex &index = static_cast<PredicateAttribute &>(*a1).getIndex();
EXPECT_EQ(1u, index.getZeroConstraintDocs().size());
EXPECT_FALSE(index.getIntervalIndex().lookup(PredicateHash::hash64("foo=bar")).valid());
- bool immediateCommit = true;
DummyFieldUpdateCallback onUpdate;
- update(2, upd, 1, immediateCommit, onUpdate);
+ update(2, upd, 1, onUpdate);
EXPECT_EQ(0u, index.getZeroConstraintDocs().size());
EXPECT_TRUE(index.getIntervalIndex().lookup(PredicateHash::hash64("foo=bar")).valid());
}
@@ -712,9 +715,8 @@ TEST_F(AttributeWriterTest, handles_tensor_assign_update)
new_value = EngineOrFactory::get().copy(*new_tensor);
upd.addUpdate(FieldUpdate(upd.getType().getField("a1"))
.addUpdate(AssignValueUpdate(new_value)));
- bool immediateCommit = true;
DummyFieldUpdateCallback onUpdate;
- update(2, upd, 1, immediateCommit, onUpdate);
+ update(2, upd, 1, onUpdate);
EXPECT_EQ(2u, a1->getNumDocs());
EXPECT_TRUE(tensorAttribute != nullptr);
tensor2 = tensorAttribute->getTensor(1);
@@ -1078,7 +1080,7 @@ TEST_F(StructArrayWriterTest, update_with_doc_argument_updates_struct_field_attr
put(10, *doc, 1);
checkAttrs(1, 10, {11, 12});
doc = makeDoc(20, {21});
- update(11, *doc, 1, true);
+ update(11, *doc, 1);
checkAttrs(1, 10, {21});
}
@@ -1135,7 +1137,7 @@ TEST_F(StructMapWriterTest, update_with_doc_argument_updates_struct_field_attrib
put(10, *doc, 1);
checkAttrs(1, 10, {{1, 11}, {2, 12}});
doc = makeDoc(20, {{42, 21}});
- update(11, *doc, 1, true);
+ update(11, *doc, 1);
checkAttrs(1, 10, {{42, 21}});
}
diff --git a/searchcore/src/tests/proton/docsummary/docsummary.cpp b/searchcore/src/tests/proton/docsummary/docsummary.cpp
index 266a817d380..1d8b85864f6 100644
--- a/searchcore/src/tests/proton/docsummary/docsummary.cpp
+++ b/searchcore/src/tests/proton/docsummary/docsummary.cpp
@@ -97,7 +97,7 @@ public:
BuildContext(const Schema &schema)
: _dmk("summary"),
_bld(schema),
- _repo(new DocumentTypeRepo(_bld.getDocumentType())),
+ _repo(std::make_shared<DocumentTypeRepo>(_bld.getDocumentType())),
_summaryExecutor(4, 128 * 1024),
_noTlSyncer(),
_str(_summaryExecutor, "summary",
@@ -125,7 +125,7 @@ public:
}
FieldCacheRepo::UP createFieldCacheRepo(const ResultConfig &resConfig) const {
- return FieldCacheRepo::UP(new FieldCacheRepo(resConfig, _bld.getDocumentType()));
+ return std::make_unique<FieldCacheRepo>(resConfig, _bld.getDocumentType());
}
};
@@ -150,8 +150,7 @@ vespalib::string asVstring(const Inspector &value) {
}
void decode(const ResEntry *entry, vespalib::Slime &slime) {
- vespalib::Memory mem(entry->_dataval,
- entry->_datalen);
+ vespalib::Memory mem(entry->_dataval, entry->_datalen);
size_t decodeRes = BinaryFormat::decode(mem, slime);
ASSERT_EQUAL(decodeRes, mem.size);
}
@@ -216,14 +215,14 @@ public:
if (! FastOS_File::MakeDirectory((std::string("tmpdb/") + docTypeName).c_str())) {
LOG_ABORT("should not be reached");
}
- _ddb.reset(new DocumentDB("tmpdb", _configMgr.getConfig(), "tcp/localhost:9013", _queryLimiter, _clock,
- DocTypeName(docTypeName), makeBucketSpace(),
- *b->getProtonConfigSP(), *this, _summaryExecutor, _summaryExecutor,
- _tls, _dummy, _fileHeaderContext, ConfigStore::UP(new MemoryConfigStore),
- std::make_shared<vespalib::ThreadStackExecutor>(16, 128 * 1024), _hwInfo)),
+ _ddb = std::make_unique<DocumentDB>("tmpdb", _configMgr.getConfig(), "tcp/localhost:9013", _queryLimiter, _clock,
+ DocTypeName(docTypeName), makeBucketSpace(), *b->getProtonConfigSP(), *this,
+ _summaryExecutor, _summaryExecutor, _tls, _dummy, _fileHeaderContext,
+ std::make_unique<MemoryConfigStore>(),
+ std::make_shared<vespalib::ThreadStackExecutor>(16, 128 * 1024), _hwInfo),
_ddb->start();
_ddb->waitForOnlineState();
- _aw = AttributeWriter::UP(new AttributeWriter(_ddb->getReadySubDB()->getAttributeManager()));
+ _aw = std::make_unique<AttributeWriter>(_ddb->getReadySubDB()->getAttributeManager());
_sa = _ddb->getReadySubDB()->getSummaryAdapter();
}
~DBContext()
@@ -246,7 +245,8 @@ public:
Timestamp(0u), docSize, lid, 0u));
LOG_ASSERT(putRes.ok());
uint64_t serialNum = _ddb->getFeedHandler().incSerialNum();
- _aw->put(serialNum, doc, lid, true, std::shared_ptr<IDestructorCallback>());
+ _aw->put(serialNum, doc, lid, std::shared_ptr<IDestructorCallback>());
+ _aw->forceCommit(serialNum, std::shared_ptr<IDestructorCallback>());
_ddb->getReadySubDB()->getAttributeManager()->getAttributeFieldWriter().sync();
_sa->put(serialNum, lid, doc);
const GlobalId &gid = docId.getGlobalId();
@@ -259,10 +259,11 @@ public:
op->setSerialNum(serialNum);
op->setDbDocumentId(dbdId);
op->setPrevDbDocumentId(prevDbdId);
- _ddb->getWriteService().master().execute(vespalib::makeLambdaTask([this, op = std::move(op)]() {
- _ddb->getFeedHandler().appendOperation(*op, std::make_shared<search::IgnoreCallback>());
+ vespalib::Gate commitDone;
+ _ddb->getWriteService().master().execute(vespalib::makeLambdaTask([this, op = std::move(op), &commitDone]() {
+ _ddb->getFeedHandler().appendOperation(*op, std::make_shared<search::GateCallback>(commitDone));
}));
- _ddb->getWriteService().master().sync();
+ commitDone.await();
SearchView *sv(dynamic_cast<SearchView *>(_ddb->getReadySubDB()->getSearchView().get()));
if (sv != nullptr) {
// cf. FeedView::putAttributes()
diff --git a/searchcore/src/tests/proton/documentdb/document_subdbs/document_subdbs_test.cpp b/searchcore/src/tests/proton/documentdb/document_subdbs/document_subdbs_test.cpp
index 43f16e87986..754cf4ea15d 100644
--- a/searchcore/src/tests/proton/documentdb/document_subdbs/document_subdbs_test.cpp
+++ b/searchcore/src/tests/proton/documentdb/document_subdbs/document_subdbs_test.cpp
@@ -769,18 +769,27 @@ struct DocumentHandler
}
void putDoc(PutOperation &op) {
IFeedView::SP feedView = _f._subDb.getFeedView();
- _f.runInMaster([&]() { feedView->preparePut(op);
- feedView->handlePut(FeedToken(), op); } );
+ _f.runInMaster([&]() {
+ feedView->preparePut(op);
+ feedView->handlePut(FeedToken(), op);
+ feedView->forceCommit(op.getSerialNum());
+ } );
}
void moveDoc(MoveOperation &op) {
IFeedView::SP feedView = _f._subDb.getFeedView();
- _f.runInMaster([&]() { feedView->handleMove(op, IDestructorCallback::SP()); } );
+ _f.runInMaster([&]() {
+ feedView->handleMove(op, IDestructorCallback::SP());
+ feedView->forceCommit(op.getSerialNum());
+ } );
}
void removeDoc(RemoveOperation &op)
{
IFeedView::SP feedView = _f._subDb.getFeedView();
- _f.runInMaster([&]() { feedView->prepareRemove(op);
- feedView->handleRemove(FeedToken(), op); } );
+ _f.runInMaster([&]() {
+ feedView->prepareRemove(op);
+ feedView->handleRemove(FeedToken(), op);
+ feedView->forceCommit(op.getSerialNum());
+ } );
}
void putDocs() {
PutOperation putOp = createPut(std::move(createDoc(1, 22, 33)), Timestamp(10), 10);
diff --git a/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp b/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp
index 1269804c98a..9bb8865707d 100644
--- a/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp
+++ b/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp
@@ -92,20 +92,17 @@ struct MyTracer
_os << ")";
}
- void tracePut(const vespalib::string &adapterType,
- SerialNum serialNum, uint32_t lid, bool immediateCommit) {
+ void tracePut(const vespalib::string &adapterType, SerialNum serialNum, uint32_t lid) {
Guard guard(_mutex);
addComma();
_os << "put(adapter=" << adapterType <<
- ",serialNum=" << serialNum << ",lid=" << lid << ",commit=" << immediateCommit << ")";
+ ",serialNum=" << serialNum << ",lid=" << lid << ")";
}
- void traceRemove(const vespalib::string &adapterType,
- SerialNum serialNum, uint32_t lid, bool immediateCommit) {
+ void traceRemove(const vespalib::string &adapterType, SerialNum serialNum, uint32_t lid) {
Guard guard(_mutex);
addComma();
- _os << "remove(adapter=" << adapterType <<
- ",serialNum=" << serialNum << ",lid=" << lid << ",commit=" << immediateCommit << ")";
+ _os << "remove(adapter=" << adapterType << ",serialNum=" << serialNum << ",lid=" << lid << ")";
}
void traceCommit(const vespalib::string &adapterType, SerialNum serialNum) {
@@ -151,12 +148,12 @@ struct MyIndexWriter : public test::MockIndexWriter
{}
void put(SerialNum serialNum, const document::Document &doc, const DocumentIdT lid) override {
(void) doc;
- _tracer.tracePut(indexAdapterTypeName, serialNum, lid, false);
+ _tracer.tracePut(indexAdapterTypeName, serialNum, lid);
}
void remove(SerialNum serialNum, const search::DocumentIdT lid) override {
LOG(info, "MyIndexAdapter::remove(): serialNum(%" PRIu64 "), docId(%u)", serialNum, lid);
_removes.push_back(lid);
- _tracer.traceRemove(indexAdapterTypeName, serialNum, lid, false);
+ _tracer.traceRemove(indexAdapterTypeName, serialNum, lid);
}
void commit(SerialNum serialNum, OnWriteDoneType) override {
++_commitCount;
@@ -335,35 +332,26 @@ struct MyAttributeWriter : public IAttributeWriter
AttrMap::const_iterator itr = _attrMap.find(attrName);
return ((itr == _attrMap.end()) ? nullptr : itr->second.get());
}
- void put(SerialNum serialNum, const document::Document &doc, DocumentIdT lid,
- bool immediateCommit, OnWriteDoneType) override {
+ void put(SerialNum serialNum, const document::Document &doc, DocumentIdT lid, OnWriteDoneType) override {
_putSerial = serialNum;
_putDocId = doc.getId();
_putLid = lid;
- _tracer.tracePut(attributeAdapterTypeName, serialNum, lid, immediateCommit);
- if (immediateCommit) {
- ++_commitCount;
- }
+ _tracer.tracePut(attributeAdapterTypeName, serialNum, lid);
}
- void remove(SerialNum serialNum, DocumentIdT lid,
- bool immediateCommit, OnWriteDoneType) override {
+ void remove(SerialNum serialNum, DocumentIdT lid, OnWriteDoneType) override {
_removeSerial = serialNum;
_removeLid = lid;
- _tracer.traceRemove(attributeAdapterTypeName, serialNum, lid, immediateCommit);
- if (immediateCommit) {
- ++_commitCount;
- }
+ _tracer.traceRemove(attributeAdapterTypeName, serialNum, lid);
}
- void remove(const LidVector & lidsToRemove, SerialNum serialNum,
- bool immediateCommit, OnWriteDoneType) override {
+ void remove(const LidVector & lidsToRemove, SerialNum serialNum, OnWriteDoneType) override {
for (uint32_t lid : lidsToRemove) {
LOG(info, "MyAttributeAdapter::remove(): serialNum(%" PRIu64 "), docId(%u)", serialNum, lid);
_removes.push_back(lid);
- _tracer.traceRemove(attributeAdapterTypeName, serialNum, lid, immediateCommit);
+ _tracer.traceRemove(attributeAdapterTypeName, serialNum, lid);
}
}
void update(SerialNum serialNum, const document::DocumentUpdate &upd,
- DocumentIdT lid, bool, OnWriteDoneType, IFieldUpdateCallback & onUpdate) override {
+ DocumentIdT lid, OnWriteDoneType, IFieldUpdateCallback & onUpdate) override {
_updateSerial = serialNum;
_updateDocId = upd.getId();
_updateLid = lid;
@@ -372,12 +360,10 @@ struct MyAttributeWriter : public IAttributeWriter
onUpdate.onUpdateField(fieldUpdate.getField().getName(), attr);
}
}
- void update(SerialNum serialNum, const document::Document &doc, DocumentIdT lid,
- bool immediateCommit, OnWriteDoneType) override {
+ void update(SerialNum serialNum, const document::Document &doc, DocumentIdT lid, OnWriteDoneType) override {
(void) serialNum;
(void) doc;
(void) lid;
- (void) immediateCommit;
}
void heartBeat(SerialNum) override { ++_heartBeatCount; }
void compactLidSpace(uint32_t wantedLidLimit, SerialNum ) override {
@@ -818,6 +804,7 @@ TEST_F("require that put() calls attribute adapter", SearchableFeedViewFixture)
DocumentContext dc = f.doc1();
EXPECT_EQUAL(0u, f._docIdLimit.get());
f.putAndWait(dc);
+ f.forceCommitAndWait();
EXPECT_EQUAL(1u, f.maw._putSerial);
EXPECT_EQUAL(DocumentId("id:ns:searchdocument::1"), f.maw._putDocId);
@@ -1184,26 +1171,6 @@ TEST_F("require that compactLidSpace() propagates to index writer",
EXPECT_EQUAL(2u, f.miw._wantedLidLimit);
}
-TEST_F("require that commit is called if visibility delay is 0",
- SearchableFeedViewFixture)
-{
- DocumentContext dc = f.doc1();
- f.putAndWait(dc);
- EXPECT_EQUAL(1u, f.miw._commitCount);
- EXPECT_EQUAL(1u, f.maw._commitCount);
- f.removeAndWait(dc);
- EXPECT_EQUAL(2u, f.miw._commitCount);
- EXPECT_EQUAL(2u, f.maw._commitCount);
- f.assertTrace("put(adapter=attribute,serialNum=1,lid=1,commit=1),"
- "put(adapter=index,serialNum=1,lid=1,commit=0),"
- "commit(adapter=index,serialNum=1),"
- "ack(Result(0, )),"
- "remove(adapter=attribute,serialNum=2,lid=1,commit=1),"
- "remove(adapter=index,serialNum=2,lid=1,commit=0),"
- "commit(adapter=index,serialNum=2),"
- "ack(Result(0, ))");
-}
-
const vespalib::duration LONG_DELAY = 60s;
const vespalib::duration SHORT_DELAY = 500ms;
@@ -1219,11 +1186,11 @@ TEST_F("require that commit is not called when inside a commit interval",
EXPECT_EQUAL(0u, f.miw._commitCount);
EXPECT_EQUAL(0u, f.maw._commitCount);
EXPECT_EQUAL(0u, f._docIdLimit.get());
- f.assertTrace("put(adapter=attribute,serialNum=1,lid=1,commit=0),"
- "put(adapter=index,serialNum=1,lid=1,commit=0),"
+ f.assertTrace("put(adapter=attribute,serialNum=1,lid=1),"
+ "put(adapter=index,serialNum=1,lid=1),"
"ack(Result(0, )),"
- "remove(adapter=attribute,serialNum=2,lid=1,commit=0),"
- "remove(adapter=index,serialNum=2,lid=1,commit=0),"
+ "remove(adapter=attribute,serialNum=2,lid=1),"
+ "remove(adapter=index,serialNum=2,lid=1),"
"ack(Result(0, ))");
f.forceCommitAndWait();
}
@@ -1242,11 +1209,11 @@ TEST_F("require that commit is not implicitly called",
EXPECT_EQUAL(0u, f.miw._commitCount);
EXPECT_EQUAL(0u, f.maw._commitCount);
EXPECT_EQUAL(0u, f._docIdLimit.get());
- f.assertTrace("put(adapter=attribute,serialNum=1,lid=1,commit=0),"
- "put(adapter=index,serialNum=1,lid=1,commit=0),"
+ f.assertTrace("put(adapter=attribute,serialNum=1,lid=1),"
+ "put(adapter=index,serialNum=1,lid=1),"
"ack(Result(0, )),"
- "remove(adapter=attribute,serialNum=2,lid=1,commit=0),"
- "remove(adapter=index,serialNum=2,lid=1,commit=0),"
+ "remove(adapter=attribute,serialNum=2,lid=1),"
+ "remove(adapter=index,serialNum=2,lid=1),"
"ack(Result(0, ))");
f.forceCommitAndWait();
}
@@ -1263,8 +1230,8 @@ TEST_F("require that forceCommit updates docid limit",
EXPECT_EQUAL(1u, f.miw._commitCount);
EXPECT_EQUAL(1u, f.maw._commitCount);
EXPECT_EQUAL(2u, f._docIdLimit.get());
- f.assertTrace("put(adapter=attribute,serialNum=1,lid=1,commit=0),"
- "put(adapter=index,serialNum=1,lid=1,commit=0),"
+ f.assertTrace("put(adapter=attribute,serialNum=1,lid=1),"
+ "put(adapter=index,serialNum=1,lid=1),"
"ack(Result(0, )),"
"commit(adapter=attribute,serialNum=1),"
"commit(adapter=index,serialNum=1)");
diff --git a/searchcore/src/tests/proton/documentdb/storeonlyfeedview/storeonlyfeedview_test.cpp b/searchcore/src/tests/proton/documentdb/storeonlyfeedview/storeonlyfeedview_test.cpp
index e6e71d51e47..3a75f8cd494 100644
--- a/searchcore/src/tests/proton/documentdb/storeonlyfeedview/storeonlyfeedview_test.cpp
+++ b/searchcore/src/tests/proton/documentdb/storeonlyfeedview/storeonlyfeedview_test.cpp
@@ -105,12 +105,12 @@ struct MyMinimalFeedView : public MyMinimalFeedViewBase, public StoreOnlyFeedVie
outstandingMoveOps(outstandingMoveOps_)
{
}
- void removeAttributes(SerialNum s, const LidVector &l, bool immediateCommit, OnWriteDoneType onWriteDone) override {
- StoreOnlyFeedView::removeAttributes(s, l, immediateCommit, onWriteDone);
+ void removeAttributes(SerialNum s, const LidVector &l, OnWriteDoneType onWriteDone) override {
+ StoreOnlyFeedView::removeAttributes(s, l, onWriteDone);
++removeMultiAttributesCount;
}
- void removeIndexedFields(SerialNum s, const LidVector &l, bool immediateCommit, OnWriteDoneType onWriteDone) override {
- StoreOnlyFeedView::removeIndexedFields(s, l, immediateCommit, onWriteDone);
+ void removeIndexedFields(SerialNum s, const LidVector &l, OnWriteDoneType onWriteDone) override {
+ StoreOnlyFeedView::removeIndexedFields(s, l, onWriteDone);
++removeMultiIndexFieldsCount;
}
void heartBeatIndexedFields(SerialNum s) override {
@@ -145,23 +145,23 @@ struct MoveOperationFeedView : public MyMinimalFeedView {
removeIndexFieldsCount(0),
onWriteDoneContexts()
{}
- void putAttributes(SerialNum, search::DocumentIdT, const document::Document &, bool, OnPutDoneType onWriteDone) override {
+ void putAttributes(SerialNum, search::DocumentIdT, const document::Document &, OnPutDoneType onWriteDone) override {
++putAttributesCount;
EXPECT_EQUAL(1, outstandingMoveOps);
onWriteDoneContexts.push_back(onWriteDone);
}
void putIndexedFields(SerialNum, search::DocumentIdT, const document::Document::SP &,
- bool, OnOperationDoneType onWriteDone) override {
+ OnOperationDoneType onWriteDone) override {
++putIndexFieldsCount;
EXPECT_EQUAL(1, outstandingMoveOps);
onWriteDoneContexts.push_back(onWriteDone);
}
- void removeAttributes(SerialNum, search::DocumentIdT, bool, OnRemoveDoneType onWriteDone) override {
+ void removeAttributes(SerialNum, search::DocumentIdT, OnRemoveDoneType onWriteDone) override {
++removeAttributesCount;
EXPECT_EQUAL(1, outstandingMoveOps);
onWriteDoneContexts.push_back(onWriteDone);
}
- void removeIndexedFields(SerialNum, search::DocumentIdT, bool, OnRemoveDoneType onWriteDone) override {
+ void removeIndexedFields(SerialNum, search::DocumentIdT, OnRemoveDoneType onWriteDone) override {
++removeIndexFieldsCount;
EXPECT_EQUAL(1, outstandingMoveOps);
onWriteDoneContexts.push_back(onWriteDone);
diff --git a/searchcore/src/vespa/searchcore/proton/attribute/attribute_populator.cpp b/searchcore/src/vespa/searchcore/proton/attribute/attribute_populator.cpp
index 33a5776cb8a..af7bae32b11 100644
--- a/searchcore/src/vespa/searchcore/proton/attribute/attribute_populator.cpp
+++ b/searchcore/src/vespa/searchcore/proton/attribute/attribute_populator.cpp
@@ -3,6 +3,8 @@
#include "attribute_populator.h"
#include <vespa/searchcore/proton/common/eventlogger.h>
#include <vespa/searchlib/common/idestructorcallback.h>
+#include <vespa/searchlib/common/gatecallback.h>
+#include <vespa/vespalib/util/gate.h>
#include <vespa/searchlib/attribute/attributevector.h>
#include <vespa/log/log.h>
@@ -73,8 +75,10 @@ void
AttributePopulator::handleExisting(uint32_t lid, const std::shared_ptr<document::Document> &doc)
{
search::SerialNum serialNum(nextSerialNum());
- auto populateDoneContext = std::make_shared<PopulateDoneContext>(doc);
- _writer.put(serialNum, *doc, lid, true, populateDoneContext);
+ _writer.put(serialNum, *doc, lid, std::make_shared<PopulateDoneContext>(doc));
+ vespalib::Gate gate;
+ _writer.forceCommit(serialNum, std::make_shared<search::GateCallback>(gate));
+ gate.await();
}
void
diff --git a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp
index bf32b679d76..2b859c17931 100644
--- a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp
+++ b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.cpp
@@ -127,8 +127,7 @@ ensureLidSpace(SerialNum serialNum, DocumentIdT lid, AttributeVector &attr)
void
applyPutToAttribute(SerialNum serialNum, const FieldValue::UP &fieldValue, DocumentIdT lid,
- bool immediateCommit, AttributeVector &attr,
- AttributeWriter::OnWriteDoneType)
+ AttributeVector &attr, AttributeWriter::OnWriteDoneType)
{
ensureLidSpace(serialNum, lid, attr);
if (fieldValue.get()) {
@@ -136,9 +135,6 @@ applyPutToAttribute(SerialNum serialNum, const FieldValue::UP &fieldValue, Docum
} else {
attr.clearDoc(lid);
}
- if (immediateCommit) {
- attr.commit(serialNum, serialNum);
- }
}
void
@@ -147,7 +143,6 @@ complete_put_to_attribute(SerialNum serial_num,
AttributeVector& attr,
const FieldValue::SP& field_value,
std::future<std::unique_ptr<PrepareResult>>& result_future,
- bool immediate_commit,
AttributeWriter::OnWriteDoneType)
{
ensureLidSpace(serial_num, docid, attr);
@@ -157,20 +152,14 @@ complete_put_to_attribute(SerialNum serial_num,
} else {
attr.clearDoc(docid);
}
- if (immediate_commit) {
- attr.commit(serial_num, serial_num);
- }
}
void
-applyRemoveToAttribute(SerialNum serialNum, DocumentIdT lid, bool immediateCommit,
+applyRemoveToAttribute(SerialNum serialNum, DocumentIdT lid,
AttributeVector &attr, AttributeWriter::OnWriteDoneType)
{
ensureLidSpace(serialNum, lid, attr);
attr.clearDoc(lid);
- if (immediateCommit) {
- attr.commit(serialNum, serialNum);
- }
}
void
@@ -182,15 +171,6 @@ applyUpdateToAttribute(SerialNum serialNum, const FieldUpdate &fieldUpd,
}
void
-applyUpdateToAttributeAndCommit(SerialNum serialNum, const FieldUpdate &fieldUpd,
- DocumentIdT lid, AttributeVector &attr)
-{
- ensureLidSpace(serialNum, lid, attr);
- AttributeUpdater::handleUpdate(attr, lid, fieldUpd);
- attr.commit(serialNum, serialNum);
-}
-
-void
applyReplayDone(uint32_t docIdLimit, AttributeVector &attr)
{
AttributeManager::padAttribute(attr, docIdLimit);
@@ -240,30 +220,22 @@ using AttrUpdates = std::vector<std::pair<AttributeVector *, const FieldUpdate *
struct BatchUpdateTask : public vespalib::Executor::Task {
- BatchUpdateTask(SerialNum serialNum, DocumentIdT lid, bool immediateCommit)
+ BatchUpdateTask(SerialNum serialNum, DocumentIdT lid)
: vespalib::Executor::Task(),
_serialNum(serialNum),
_lid(lid),
- _immediateCommit(immediateCommit),
_onWriteDone()
{ }
~BatchUpdateTask() override;
void run() override {
- if (_immediateCommit) {
- for (const auto & update : _updates) {
- applyUpdateToAttributeAndCommit(_serialNum, *update.second, _lid, *update.first);
- }
- } else {
- for (const auto & update : _updates) {
- applyUpdateToAttribute(_serialNum, *update.second, _lid, *update.first);
- }
+ for (const auto & update : _updates) {
+ applyUpdateToAttribute(_serialNum, *update.second, _lid, *update.first);
}
}
SerialNum _serialNum;
DocumentIdT _lid;
- bool _immediateCommit;
AttrUpdates _updates;
search::IDestructorCallback::SP _onWriteDone;
};
@@ -310,22 +282,20 @@ class PutTask : public vespalib::Executor::Task
const AttributeWriter::WriteContext &_wc;
const SerialNum _serialNum;
const uint32_t _lid;
- const bool _immediateCommit;
const bool _allAttributes;
std::remove_reference_t<AttributeWriter::OnWriteDoneType> _onWriteDone;
std::shared_ptr<DocumentFieldExtractor> _fieldExtractor;
std::vector<FieldValue::UP> _fieldValues;
public:
- PutTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, std::shared_ptr<DocumentFieldExtractor> fieldExtractor, uint32_t lid, bool immediateCommit, bool allAttributes, AttributeWriter::OnWriteDoneType onWriteDone);
+ PutTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, std::shared_ptr<DocumentFieldExtractor> fieldExtractor, uint32_t lid, bool allAttributes, AttributeWriter::OnWriteDoneType onWriteDone);
~PutTask() override;
void run() override;
};
-PutTask::PutTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, std::shared_ptr<DocumentFieldExtractor> fieldExtractor, uint32_t lid, bool immediateCommit, bool allAttributes, AttributeWriter::OnWriteDoneType onWriteDone)
+PutTask::PutTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, std::shared_ptr<DocumentFieldExtractor> fieldExtractor, uint32_t lid, bool allAttributes, AttributeWriter::OnWriteDoneType onWriteDone)
: _wc(wc),
_serialNum(serialNum),
_lid(lid),
- _immediateCommit(immediateCommit),
_allAttributes(allAttributes),
_onWriteDone(onWriteDone),
_fieldExtractor(std::move(fieldExtractor)),
@@ -352,7 +322,7 @@ PutTask::run()
if (_allAttributes || field.isStructFieldAttribute()) {
AttributeVector &attr = field.getAttribute();
if (attr.getStatus().getLastSyncToken() < _serialNum) {
- applyPutToAttribute(_serialNum, _fieldValues[fieldId], _lid, _immediateCommit, attr, _onWriteDone);
+ applyPutToAttribute(_serialNum, _fieldValues[fieldId], _lid, attr, _onWriteDone);
}
++fieldId;
}
@@ -418,26 +388,22 @@ private:
AttributeVector& _attr;
FieldValue::SP _field_value;
std::future<std::unique_ptr<PrepareResult>> _result_future;
- const bool _immediate_commit;
std::remove_reference_t<AttributeWriter::OnWriteDoneType> _on_write_done;
public:
CompletePutTask(PreparePutTask& prepare_task,
- bool immediate_commit,
AttributeWriter::OnWriteDoneType on_write_done);
~CompletePutTask() override;
void run() override;
};
CompletePutTask::CompletePutTask(PreparePutTask& prepare_task,
- bool immediate_commit,
AttributeWriter::OnWriteDoneType on_write_done)
: _serial_num(prepare_task.serial_num()),
_docid(prepare_task.docid()),
_attr(prepare_task.attr()),
_field_value(prepare_task.field_value()),
_result_future(prepare_task.result_future()),
- _immediate_commit(immediate_commit),
_on_write_done(on_write_done)
{
}
@@ -448,8 +414,7 @@ void
CompletePutTask::run()
{
if (_attr.getStatus().getLastSyncToken() < _serial_num) {
- complete_put_to_attribute(_serial_num, _docid, _attr, _field_value, _result_future,
- _immediate_commit, _on_write_done);
+ complete_put_to_attribute(_serial_num, _docid, _attr, _field_value, _result_future, _on_write_done);
}
}
@@ -458,19 +423,17 @@ class RemoveTask : public vespalib::Executor::Task
const AttributeWriter::WriteContext &_wc;
const SerialNum _serialNum;
const uint32_t _lid;
- const bool _immediateCommit;
std::remove_reference_t<AttributeWriter::OnWriteDoneType> _onWriteDone;
public:
- RemoveTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, uint32_t lid, bool immediateCommit, AttributeWriter::OnWriteDoneType onWriteDone);
+ RemoveTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, uint32_t lid, AttributeWriter::OnWriteDoneType onWriteDone);
~RemoveTask() override;
void run() override;
};
-RemoveTask::RemoveTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, uint32_t lid, bool immediateCommit, AttributeWriter::OnWriteDoneType onWriteDone)
+RemoveTask::RemoveTask(const AttributeWriter::WriteContext &wc, SerialNum serialNum, uint32_t lid, AttributeWriter::OnWriteDoneType onWriteDone)
: _wc(wc),
_serialNum(serialNum),
_lid(lid),
- _immediateCommit(immediateCommit),
_onWriteDone(onWriteDone)
{
}
@@ -485,7 +448,7 @@ RemoveTask::run()
AttributeVector &attr = field.getAttribute();
// Must use <= due to how move operations are handled
if (attr.getStatus().getLastSyncToken() <= _serialNum) {
- applyRemoveToAttribute(_serialNum, _lid, _immediateCommit, attr, _onWriteDone);
+ applyRemoveToAttribute(_serialNum, _lid, attr, _onWriteDone);
}
}
}
@@ -496,18 +459,15 @@ private:
const AttributeWriter::WriteContext &_writeCtx;
const SerialNum _serialNum;
const LidVector _lidsToRemove;
- const bool _immediateCommit;
std::remove_reference_t<AttributeWriter::OnWriteDoneType> _onWriteDone;
public:
BatchRemoveTask(const AttributeWriter::WriteContext &writeCtx,
SerialNum serialNum,
const LidVector &lidsToRemove,
- bool immediateCommit,
AttributeWriter::OnWriteDoneType onWriteDone)
: _writeCtx(writeCtx),
_serialNum(serialNum),
_lidsToRemove(lidsToRemove),
- _immediateCommit(immediateCommit),
_onWriteDone(onWriteDone)
{}
~BatchRemoveTask() override;
@@ -516,10 +476,7 @@ public:
auto &attr = field.getAttribute();
if (attr.getStatus().getLastSyncToken() < _serialNum) {
for (auto lidToRemove : _lidsToRemove) {
- applyRemoveToAttribute(_serialNum, lidToRemove, false, attr, _onWriteDone);
- }
- if (_immediateCommit) {
- attr.commit(_serialNum, _serialNum);
+ applyRemoveToAttribute(_serialNum, lidToRemove, attr, _onWriteDone);
}
}
}
@@ -604,7 +561,7 @@ AttributeWriter::buildFieldPaths(const DocumentType & docType, const DataType *d
void
AttributeWriter::internalPut(SerialNum serialNum, const Document &doc, DocumentIdT lid,
- bool immediateCommit, bool allAttributes, OnWriteDoneType onWriteDone)
+ bool allAttributes, OnWriteDoneType onWriteDone)
{
const DataType *dataType(doc.getDataType());
if (_dataType != dataType) {
@@ -615,13 +572,12 @@ AttributeWriter::internalPut(SerialNum serialNum, const Document &doc, DocumentI
if (wc.use_two_phase_put()) {
assert(wc.getFields().size() == 1);
auto prepare_task = std::make_unique<PreparePutTask>(serialNum, lid, wc.getFields()[0], extractor);
- auto complete_task = std::make_unique<CompletePutTask>(*prepare_task, immediateCommit, onWriteDone);
+ auto complete_task = std::make_unique<CompletePutTask>(*prepare_task, onWriteDone);
_shared_executor.execute(std::move(prepare_task));
_attributeFieldWriter.executeTask(wc.getExecutorId(), std::move(complete_task));
} else {
if (allAttributes || wc.hasStructFieldAttribute()) {
- auto putTask = std::make_unique<PutTask>(wc, serialNum, extractor, lid, immediateCommit, allAttributes,
- onWriteDone);
+ auto putTask = std::make_unique<PutTask>(wc, serialNum, extractor, lid, allAttributes, onWriteDone);
_attributeFieldWriter.executeTask(wc.getExecutorId(), std::move(putTask));
}
}
@@ -629,11 +585,10 @@ AttributeWriter::internalPut(SerialNum serialNum, const Document &doc, DocumentI
}
void
-AttributeWriter::internalRemove(SerialNum serialNum, DocumentIdT lid, bool immediateCommit,
- OnWriteDoneType onWriteDone)
+AttributeWriter::internalRemove(SerialNum serialNum, DocumentIdT lid, OnWriteDoneType onWriteDone)
{
for (const auto &wc : _writeContexts) {
- auto removeTask = std::make_unique<RemoveTask>(wc, serialNum, lid, immediateCommit, onWriteDone);
+ auto removeTask = std::make_unique<RemoveTask>(wc, serialNum, lid, onWriteDone);
_attributeFieldWriter.executeTask(wc.getExecutorId(), std::move(removeTask));
}
}
@@ -678,50 +633,46 @@ AttributeWriter::getWritableAttribute(const vespalib::string &name) const
}
void
-AttributeWriter::put(SerialNum serialNum, const Document &doc, DocumentIdT lid,
- bool immediateCommit, OnWriteDoneType onWriteDone)
+AttributeWriter::put(SerialNum serialNum, const Document &doc, DocumentIdT lid, OnWriteDoneType onWriteDone)
{
LOG(spam, "Handle put: serial(%" PRIu64 "), docId(%s), lid(%u), document(%s)",
serialNum, doc.getId().toString().c_str(), lid, doc.toString(true).c_str());
- internalPut(serialNum, doc, lid, immediateCommit, true, onWriteDone);
+ internalPut(serialNum, doc, lid, true, onWriteDone);
}
void
-AttributeWriter::update(SerialNum serialNum, const Document &doc, DocumentIdT lid,
- bool immediateCommit, OnWriteDoneType onWriteDone)
+AttributeWriter::update(SerialNum serialNum, const Document &doc, DocumentIdT lid, OnWriteDoneType onWriteDone)
{
LOG(spam, "Handle update: serial(%" PRIu64 "), docId(%s), lid(%u), document(%s)",
serialNum, doc.getId().toString().c_str(), lid, doc.toString(true).c_str());
- internalPut(serialNum, doc, lid, immediateCommit, false, onWriteDone);
+ internalPut(serialNum, doc, lid, false, onWriteDone);
}
void
-AttributeWriter::remove(SerialNum serialNum, DocumentIdT lid,
- bool immediateCommit, OnWriteDoneType onWriteDone)
+AttributeWriter::remove(SerialNum serialNum, DocumentIdT lid, OnWriteDoneType onWriteDone)
{
- internalRemove(serialNum, lid, immediateCommit, onWriteDone);
+ internalRemove(serialNum, lid, onWriteDone);
}
void
-AttributeWriter::remove(const LidVector &lidsToRemove, SerialNum serialNum,
- bool immediateCommit, OnWriteDoneType onWriteDone)
+AttributeWriter::remove(const LidVector &lidsToRemove, SerialNum serialNum, OnWriteDoneType onWriteDone)
{
for (const auto &writeCtx : _writeContexts) {
- auto removeTask = std::make_unique<BatchRemoveTask>(writeCtx, serialNum, lidsToRemove, immediateCommit, onWriteDone);
+ auto removeTask = std::make_unique<BatchRemoveTask>(writeCtx, serialNum, lidsToRemove, onWriteDone);
_attributeFieldWriter.executeTask(writeCtx.getExecutorId(), std::move(removeTask));
}
}
void
AttributeWriter::update(SerialNum serialNum, const DocumentUpdate &upd, DocumentIdT lid,
- bool immediateCommit, OnWriteDoneType onWriteDone, IFieldUpdateCallback & onUpdate)
+ OnWriteDoneType onWriteDone, IFieldUpdateCallback & onUpdate)
{
LOG(debug, "Inspecting update for document %d.", lid);
std::vector<std::unique_ptr<BatchUpdateTask>> args;
uint32_t numExecutors = _attributeFieldWriter.getNumExecutors();
args.reserve(numExecutors);
for (uint32_t i(0); i < numExecutors; i++) {
- args.emplace_back(std::make_unique<BatchUpdateTask>(serialNum, lid, immediateCommit));
+ args.emplace_back(std::make_unique<BatchUpdateTask>(serialNum, lid));
args.back()->_updates.reserve((2*upd.getUpdates().size())/numExecutors);
}
diff --git a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h
index 9e9e8910669..f63a2c6efba 100644
--- a/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h
+++ b/searchcore/src/vespa/searchcore/proton/attribute/attribute_writer.h
@@ -79,13 +79,12 @@ private:
void setupAttriuteMapping();
void buildFieldPaths(const DocumentType &docType, const DataType *dataType);
void internalPut(SerialNum serialNum, const Document &doc, DocumentIdT lid,
- bool immediateCommit, bool allAttributes, OnWriteDoneType onWriteDone);
- void internalRemove(SerialNum serialNum, DocumentIdT lid,
- bool immediateCommit, OnWriteDoneType onWriteDone);
+ bool allAttributes, OnWriteDoneType onWriteDone);
+ void internalRemove(SerialNum serialNum, DocumentIdT lid, OnWriteDoneType onWriteDone);
public:
AttributeWriter(proton::IAttributeManager::SP mgr);
- ~AttributeWriter();
+ ~AttributeWriter() override;
/* Only for in tests that add attributes after AttributeWriter construction. */
@@ -94,16 +93,12 @@ public:
*/
std::vector<search::AttributeVector *> getWritableAttributes() const override;
search::AttributeVector *getWritableAttribute(const vespalib::string &name) const override;
- void put(SerialNum serialNum, const Document &doc, DocumentIdT lid,
- bool immediateCommit, OnWriteDoneType onWriteDone) override;
- void remove(SerialNum serialNum, DocumentIdT lid,
- bool immediateCommit, OnWriteDoneType onWriteDone) override;
- void remove(const LidVector &lidVector, SerialNum serialNum,
- bool immediateCommit, OnWriteDoneType onWriteDone) override;
+ void put(SerialNum serialNum, const Document &doc, DocumentIdT lid, OnWriteDoneType onWriteDone) override;
+ void remove(SerialNum serialNum, DocumentIdT lid, OnWriteDoneType onWriteDone) override;
+ void remove(const LidVector &lidVector, SerialNum serialNum, OnWriteDoneType onWriteDone) override;
void update(SerialNum serialNum, const DocumentUpdate &upd, DocumentIdT lid,
- bool immediateCommit, OnWriteDoneType onWriteDone, IFieldUpdateCallback & onUpdate) override;
- void update(SerialNum serialNum, const Document &doc, DocumentIdT lid,
- bool immediateCommit, OnWriteDoneType onWriteDone) override;
+ OnWriteDoneType onWriteDone, IFieldUpdateCallback & onUpdate) override;
+ void update(SerialNum serialNum, const Document &doc, DocumentIdT lid, OnWriteDoneType onWriteDone) override;
void heartBeat(SerialNum serialNum) override;
void compactLidSpace(uint32_t wantedLidLimit, SerialNum serialNum) override;
const proton::IAttributeManager::SP &getAttributeManager() const override {
diff --git a/searchcore/src/vespa/searchcore/proton/attribute/filter_attribute_manager.cpp b/searchcore/src/vespa/searchcore/proton/attribute/filter_attribute_manager.cpp
index 3b1269b031c..ffdfdbc4332 100644
--- a/searchcore/src/vespa/searchcore/proton/attribute/filter_attribute_manager.cpp
+++ b/searchcore/src/vespa/searchcore/proton/attribute/filter_attribute_manager.cpp
@@ -233,7 +233,7 @@ FilterAttributeManager::setImportedAttributes(std::unique_ptr<ImportedAttributes
const ImportedAttributesRepo *
FilterAttributeManager::getImportedAttributes() const
{
- throw vespalib::IllegalArgumentException("Not implemented");
+ return nullptr;
}
std::shared_ptr<search::attribute::ReadableAttributeVector>
diff --git a/searchcore/src/vespa/searchcore/proton/attribute/i_attribute_writer.h b/searchcore/src/vespa/searchcore/proton/attribute/i_attribute_writer.h
index 99b5728fd3a..789a8077cba 100644
--- a/searchcore/src/vespa/searchcore/proton/attribute/i_attribute_writer.h
+++ b/searchcore/src/vespa/searchcore/proton/attribute/i_attribute_writer.h
@@ -33,27 +33,23 @@ public:
typedef document::Document Document;
using OnWriteDoneType = const std::shared_ptr<search::IDestructorCallback> &;
- virtual ~IAttributeWriter() {}
+ virtual ~IAttributeWriter() = default;
virtual std::vector<search::AttributeVector *> getWritableAttributes() const = 0;
virtual search::AttributeVector *getWritableAttribute(const vespalib::string &attrName) const = 0;
- virtual void put(SerialNum serialNum, const Document &doc, DocumentIdT lid,
- bool immediateCommit, OnWriteDoneType onWriteDone) = 0;
- virtual void remove(SerialNum serialNum, DocumentIdT lid, bool immediateCommit,
- OnWriteDoneType onWriteDone) = 0;
- virtual void remove(const LidVector &lidVector, SerialNum serialNum,
- bool immediateCommit, OnWriteDoneType onWriteDone) = 0;
+ virtual void put(SerialNum serialNum, const Document &doc, DocumentIdT lid, OnWriteDoneType onWriteDone) = 0;
+ virtual void remove(SerialNum serialNum, DocumentIdT lid, OnWriteDoneType onWriteDone) = 0;
+ virtual void remove(const LidVector &lidVector, SerialNum serialNum, OnWriteDoneType onWriteDone) = 0;
/**
* Update the underlying attributes based on the content of the given DocumentUpdate.
* The OnWriteDoneType instance should ensure the lifetime of the given DocumentUpdate instance.
*/
virtual void update(SerialNum serialNum, const DocumentUpdate &upd, DocumentIdT lid,
- bool immediateCommit, OnWriteDoneType onWriteDone, IFieldUpdateCallback & onUpdate) = 0;
+ OnWriteDoneType onWriteDone, IFieldUpdateCallback & onUpdate) = 0;
/*
* Update the underlying struct field attributes based on updated document.
*/
- virtual void update(SerialNum serialNum, const Document &doc, DocumentIdT lid,
- bool immediateCommit, OnWriteDoneType onWriteDone) = 0;
+ virtual void update(SerialNum serialNum, const Document &doc, DocumentIdT lid, OnWriteDoneType onWriteDone) = 0;
virtual void heartBeat(SerialNum serialNum) = 0;
/**
* Compact the lid space of the underlying attribute vectors.
diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.cpp b/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.cpp
index e547f3556be..7fab995dfb9 100644
--- a/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.cpp
+++ b/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.cpp
@@ -31,7 +31,6 @@ DocumentMetaStoreAttribute::DocumentMetaStoreAttribute(const vespalib::string &n
{ }
-DocumentMetaStoreAttribute::~DocumentMetaStoreAttribute()
-{ }
+DocumentMetaStoreAttribute::~DocumentMetaStoreAttribute() = default;
}
diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.h b/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.h
index 5b286907fb8..721aa8fe126 100644
--- a/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.h
+++ b/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreattribute.h
@@ -14,22 +14,20 @@ namespace proton {
class DocumentMetaStoreAttribute : public search::NotImplementedAttribute
{
protected:
- virtual void notImplemented() const override __attribute__((noinline));
+ void notImplemented() const override __attribute__((noinline));
public:
DocumentMetaStoreAttribute(const vespalib::string &name=getFixedName());
- virtual ~DocumentMetaStoreAttribute();
+ ~DocumentMetaStoreAttribute() override;
static const vespalib::string &getFixedName();
// Implements IAttributeVector
- virtual size_t
- getFixedWidth() const override
- {
+ size_t getFixedWidth() const override {
return document::GlobalId::LENGTH;
}
- virtual void onCommit() override {}
+ void onCommit() override {}
};
}
diff --git a/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.cpp b/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.cpp
index eec58ed53dc..52b4d869ce8 100644
--- a/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.cpp
+++ b/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.cpp
@@ -20,46 +20,39 @@ namespace proton {
* Otherwise we can drop it and ack the operation right away.
*/
void
-FastAccessFeedView::putAttributes(SerialNum serialNum, search::DocumentIdT lid, const Document &doc,
- bool immediateCommit, OnPutDoneType onWriteDone)
+FastAccessFeedView::putAttributes(SerialNum serialNum, search::DocumentIdT lid, const Document &doc, OnPutDoneType onWriteDone)
{
- _attributeWriter->put(serialNum, doc, lid, immediateCommit, onWriteDone);
- if (immediateCommit && onWriteDone) {
- onWriteDone->registerPutLid(&_docIdLimit);
- }
+ _attributeWriter->put(serialNum, doc, lid, onWriteDone);
}
void
FastAccessFeedView::updateAttributes(SerialNum serialNum, search::DocumentIdT lid, const DocumentUpdate &upd,
- bool immediateCommit, OnOperationDoneType onWriteDone, IFieldUpdateCallback & onUpdate)
+ OnOperationDoneType onWriteDone, IFieldUpdateCallback & onUpdate)
{
- _attributeWriter->update(serialNum, upd, lid, immediateCommit, onWriteDone, onUpdate);
+ _attributeWriter->update(serialNum, upd, lid, onWriteDone, onUpdate);
}
void
-FastAccessFeedView::updateAttributes(SerialNum serialNum, Lid lid, FutureDoc futureDoc,
- bool immediateCommit, OnOperationDoneType onWriteDone)
+FastAccessFeedView::updateAttributes(SerialNum serialNum, Lid lid, FutureDoc futureDoc, OnOperationDoneType onWriteDone)
{
if (_attributeWriter->hasStructFieldAttribute()) {
const std::unique_ptr<const Document> & doc = futureDoc.get();
if (doc) {
- _attributeWriter->update(serialNum, *doc, lid, immediateCommit, onWriteDone);
+ _attributeWriter->update(serialNum, *doc, lid, onWriteDone);
}
}
}
void
-FastAccessFeedView::removeAttributes(SerialNum serialNum, search::DocumentIdT lid,
- bool immediateCommit, OnRemoveDoneType onWriteDone)
+FastAccessFeedView::removeAttributes(SerialNum serialNum, search::DocumentIdT lid, OnRemoveDoneType onWriteDone)
{
- _attributeWriter->remove(serialNum, lid, immediateCommit, onWriteDone);
+ _attributeWriter->remove(serialNum, lid, onWriteDone);
}
void
-FastAccessFeedView::removeAttributes(SerialNum serialNum, const LidVector &lidsToRemove,
- bool immediateCommit, OnWriteDoneType onWriteDone)
+FastAccessFeedView::removeAttributes(SerialNum serialNum, const LidVector &lidsToRemove, OnWriteDoneType onWriteDone)
{
- _attributeWriter->remove(lidsToRemove, serialNum, immediateCommit, onWriteDone);
+ _attributeWriter->remove(lidsToRemove, serialNum, onWriteDone);
}
void
diff --git a/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.h b/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.h
index 08f11869b08..e0823be3e43 100644
--- a/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.h
+++ b/searchcore/src/vespa/searchcore/proton/server/fast_access_feed_view.h
@@ -36,18 +36,14 @@ private:
const IAttributeWriter::SP _attributeWriter;
DocIdLimit &_docIdLimit;
- void putAttributes(SerialNum serialNum, search::DocumentIdT lid, const Document &doc,
- bool immediateCommit, OnPutDoneType onWriteDone) override;
+ void putAttributes(SerialNum serialNum, search::DocumentIdT lid, const Document &doc, OnPutDoneType onWriteDone) override;
void updateAttributes(SerialNum serialNum, search::DocumentIdT lid, const document::DocumentUpdate &upd,
- bool immediateCommit, OnOperationDoneType onWriteDone, IFieldUpdateCallback & onUpdate) override;
- void updateAttributes(SerialNum serialNum, Lid lid, FutureDoc doc,
- bool immediateCommit, OnOperationDoneType onWriteDone) override;
- void removeAttributes(SerialNum serialNum, search::DocumentIdT lid,
- bool immediateCommit, OnRemoveDoneType onWriteDone) override;
-
- void removeAttributes(SerialNum serialNum, const LidVector &lidsToRemove,
- bool immediateCommit, OnWriteDoneType onWriteDone) override;
+ OnOperationDoneType onWriteDone, IFieldUpdateCallback & onUpdate) override;
+ void updateAttributes(SerialNum serialNum, Lid lid, FutureDoc doc, OnOperationDoneType onWriteDone) override;
+ void removeAttributes(SerialNum serialNum, search::DocumentIdT lid, OnRemoveDoneType onWriteDone) override;
+
+ void removeAttributes(SerialNum serialNum, const LidVector &lidsToRemove, OnWriteDoneType onWriteDone) override;
void heartBeatAttributes(SerialNum serialNum) override;
diff --git a/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.cpp b/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.cpp
index 360cac6e2ee..ebef7b4b6d4 100644
--- a/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.cpp
+++ b/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.cpp
@@ -59,46 +59,41 @@ SearchableFeedView::sync()
void
SearchableFeedView::putIndexedFields(SerialNum serialNum, search::DocumentIdT lid, const DocumentSP &newDoc,
- bool immediateCommit, OnOperationDoneType onWriteDone)
+ OnOperationDoneType onWriteDone)
{
if (!hasIndexedFields()) {
return;
}
_writeService.index().execute(
makeLambdaTask([=] {
- performIndexPut(serialNum, lid, newDoc, immediateCommit, onWriteDone);
+ performIndexPut(serialNum, lid, newDoc, onWriteDone);
}));
}
void
-SearchableFeedView::performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const Document &doc,
- bool immediateCommit, OnOperationDoneType onWriteDone)
+SearchableFeedView::performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const Document &doc, OnOperationDoneType onWriteDone)
{
+ (void) onWriteDone;
assert(_writeService.index().isCurrentThread());
VLOG(getDebugLevel(lid, doc.getId()),
"database(%s): performIndexPut: serialNum(%" PRIu64 "), docId(%s), lid(%d)",
_params._docTypeName.toString().c_str(), serialNum, doc.getId().toString().c_str(), lid);
_indexWriter->put(serialNum, doc, lid);
- if (immediateCommit) {
- _indexWriter->commit(serialNum, onWriteDone);
- }
}
void
-SearchableFeedView::performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const DocumentSP &doc,
- bool immediateCommit, OnOperationDoneType onWriteDone)
+SearchableFeedView::performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const DocumentSP &doc, OnOperationDoneType onWriteDone)
{
- performIndexPut(serialNum, lid, *doc, immediateCommit, onWriteDone);
+ performIndexPut(serialNum, lid, *doc, onWriteDone);
}
void
-SearchableFeedView::performIndexPut(SerialNum serialNum, search::DocumentIdT lid, FutureDoc futureDoc,
- bool immediateCommit, OnOperationDoneType onWriteDone)
+SearchableFeedView::performIndexPut(SerialNum serialNum, search::DocumentIdT lid, FutureDoc futureDoc, OnOperationDoneType onWriteDone)
{
const auto &doc = futureDoc.get();
if (doc) {
- performIndexPut(serialNum, lid, *doc, immediateCommit, onWriteDone);
+ performIndexPut(serialNum, lid, *doc, onWriteDone);
}
}
@@ -115,49 +110,44 @@ SearchableFeedView::performIndexHeartBeat(SerialNum serialNum)
}
void
-SearchableFeedView::updateIndexedFields(SerialNum serialNum, search::DocumentIdT lid, FutureDoc futureDoc,
- bool immediateCommit, OnOperationDoneType onWriteDone)
+SearchableFeedView::updateIndexedFields(SerialNum serialNum, search::DocumentIdT lid, FutureDoc futureDoc, OnOperationDoneType onWriteDone)
{
_writeService.index().execute(
makeLambdaTask([serialNum, lid, futureDoc = std::move(futureDoc),
- immediateCommit, onWriteDone = std::move(onWriteDone), this]() mutable {
- performIndexPut(serialNum, lid, std::move(futureDoc), immediateCommit, std::move(onWriteDone));
+ onWriteDone = std::move(onWriteDone), this]() mutable {
+ performIndexPut(serialNum, lid, std::move(futureDoc), std::move(onWriteDone));
}));
}
void
-SearchableFeedView::removeIndexedFields(SerialNum serialNum, search::DocumentIdT lid,
- bool immediateCommit, OnRemoveDoneType onWriteDone)
+SearchableFeedView::removeIndexedFields(SerialNum serialNum, search::DocumentIdT lid, OnRemoveDoneType onWriteDone)
{
if (!hasIndexedFields()) {
return;
}
_writeService.index().execute(
makeLambdaTask([=]() {
- performIndexRemove(serialNum, lid, immediateCommit, onWriteDone);
+ performIndexRemove(serialNum, lid, onWriteDone);
}));
}
void
-SearchableFeedView::performIndexRemove(SerialNum serialNum, search::DocumentIdT lid,
- bool immediateCommit, OnRemoveDoneType onWriteDone)
+SearchableFeedView::performIndexRemove(SerialNum serialNum, search::DocumentIdT lid, OnRemoveDoneType onWriteDone)
{
+ (void) onWriteDone;
assert(_writeService.index().isCurrentThread());
VLOG(getDebugLevel(lid, nullptr),
"database(%s): performIndexRemove: serialNum(%" PRIu64 "), lid(%d)",
_params._docTypeName.toString().c_str(), serialNum, lid);
_indexWriter->remove(serialNum, lid);
- if (immediateCommit) {
- _indexWriter->commit(serialNum, onWriteDone);
- }
}
void
-SearchableFeedView::performIndexRemove(SerialNum serialNum, const LidVector &lidsToRemove,
- bool immediateCommit, OnWriteDoneType onWriteDone)
+SearchableFeedView::performIndexRemove(SerialNum serialNum, const LidVector &lidsToRemove, OnWriteDoneType onWriteDone)
{
+ (void) onWriteDone;
assert(_writeService.index().isCurrentThread());
for (const auto lid : lidsToRemove) {
VLOG(getDebugLevel(lid, nullptr),
@@ -166,21 +156,18 @@ SearchableFeedView::performIndexRemove(SerialNum serialNum, const LidVector &lid
_indexWriter->remove(serialNum, lid);
}
- if (immediateCommit) {
- _indexWriter->commit(serialNum, onWriteDone);
- }
}
void
SearchableFeedView::removeIndexedFields(SerialNum serialNum, const LidVector &lidsToRemove,
- bool immediateCommit, OnWriteDoneType onWriteDone)
+ OnWriteDoneType onWriteDone)
{
if (!hasIndexedFields())
return;
_writeService.index().execute(
makeLambdaTask([=]() {
- performIndexRemove(serialNum, lidsToRemove, immediateCommit, onWriteDone);
+ performIndexRemove(serialNum, lidsToRemove, onWriteDone);
}));
}
diff --git a/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.h b/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.h
index 3265bc0ae70..944d383e06d 100644
--- a/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.h
+++ b/searchcore/src/vespa/searchcore/proton/server/searchable_feed_view.h
@@ -34,38 +34,21 @@ private:
bool hasIndexedFields() const { return _hasIndexedFields; }
- void performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const Document &doc,
- bool immediateCommit, OnOperationDoneType onWriteDone);
-
- void performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const DocumentSP &doc,
- bool immediateCommit, OnOperationDoneType onWriteDone);
- void performIndexPut(SerialNum serialNum, search::DocumentIdT lid, FutureDoc doc,
- bool immediateCommit, OnOperationDoneType onWriteDone);
-
- void performIndexRemove(SerialNum serialNum, search::DocumentIdT lid,
- bool immediateCommit, OnRemoveDoneType onWriteDone);
-
- void performIndexRemove(SerialNum serialNum, const LidVector &lidsToRemove,
- bool immediateCommit, OnWriteDoneType onWriteDone);
-
+ void performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const Document &doc, OnOperationDoneType onWriteDone);
+ void performIndexPut(SerialNum serialNum, search::DocumentIdT lid, const DocumentSP &doc, OnOperationDoneType onWriteDone);
+ void performIndexPut(SerialNum serialNum, search::DocumentIdT lid, FutureDoc doc, OnOperationDoneType onWriteDone);
+ void performIndexRemove(SerialNum serialNum, search::DocumentIdT lid, OnRemoveDoneType onWriteDone);
+ void performIndexRemove(SerialNum serialNum, const LidVector &lidsToRemove, OnWriteDoneType onWriteDone);
void performIndexHeartBeat(SerialNum serialNum);
-
void internalDeleteBucket(const DeleteBucketOperation &delOp) override;
void performSync();
void heartBeatIndexedFields(SerialNum serialNum) override;
- void putIndexedFields(SerialNum serialNum, search::DocumentIdT lid, const DocumentSP &newDoc,
- bool immediateCommit, OnOperationDoneType onWriteDone) override;
-
- void updateIndexedFields(SerialNum serialNum, search::DocumentIdT lid, FutureDoc newDoc,
- bool immediateCommit, OnOperationDoneType onWriteDone) override;
-
- void removeIndexedFields(SerialNum serialNum, search::DocumentIdT lid,
- bool immediateCommit, OnRemoveDoneType onWriteDone) override;
-
- void removeIndexedFields(SerialNum serialNum, const LidVector &lidsToRemove,
- bool immediateCommit, OnWriteDoneType onWriteDone) override;
+ void putIndexedFields(SerialNum serialNum, search::DocumentIdT lid, const DocumentSP &newDoc, OnOperationDoneType onWriteDone) override;
+ void updateIndexedFields(SerialNum serialNum, search::DocumentIdT lid, FutureDoc newDoc, OnOperationDoneType onWriteDone) override;
+ void removeIndexedFields(SerialNum serialNum, search::DocumentIdT lid, OnRemoveDoneType onWriteDone) override;
+ void removeIndexedFields(SerialNum serialNum, const LidVector &lidsToRemove, OnWriteDoneType onWriteDone) override;
void performIndexForceCommit(SerialNum serialNum, OnForceCommitDoneType onCommitDone);
void internalForceCommit(SerialNum serialNum, OnForceCommitDoneType onCommitDone) override;
diff --git a/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.cpp b/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.cpp
index 3db55cf6755..186c321d920 100644
--- a/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.cpp
+++ b/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.cpp
@@ -288,10 +288,10 @@ StoreOnlyFeedView::considerEarlyAck(FeedToken & token)
}
void
-StoreOnlyFeedView::putAttributes(SerialNum, Lid, const Document &, bool, OnPutDoneType) {}
+StoreOnlyFeedView::putAttributes(SerialNum, Lid, const Document &, OnPutDoneType) {}
void
-StoreOnlyFeedView::putIndexedFields(SerialNum, Lid, const Document::SP &, bool, OnOperationDoneType) {}
+StoreOnlyFeedView::putIndexedFields(SerialNum, Lid, const Document::SP &, OnOperationDoneType) {}
void
StoreOnlyFeedView::preparePut(PutOperation &putOp)
@@ -334,15 +334,14 @@ StoreOnlyFeedView::internalPut(FeedToken token, const PutOperation &putOp)
bool docAlreadyExists = putOp.getValidPrevDbdId(_params._subDbId);
if (putOp.getValidDbdId(_params._subDbId)) {
- bool immediateCommit = needImmediateCommit();
const document::GlobalId &gid = docId.getGlobalId();
std::shared_ptr<PutDoneContext> onWriteDone =
createPutDoneContext(std::move(token), std::move(uncommitted),
_gidToLidChangeHandler, doc, gid, putOp.getLid(), serialNum,
putOp.changedDbdId() && useDocumentMetaStore(serialNum));
putSummary(serialNum, putOp.getLid(), doc, onWriteDone);
- putAttributes(serialNum, putOp.getLid(), *doc, immediateCommit, onWriteDone);
- putIndexedFields(serialNum, putOp.getLid(), doc, immediateCommit, onWriteDone);
+ putAttributes(serialNum, putOp.getLid(), *doc, onWriteDone);
+ putIndexedFields(serialNum, putOp.getLid(), doc, onWriteDone);
}
if (docAlreadyExists && putOp.changedDbdId()) {
assert(!putOp.getValidDbdId(_params._subDbId));
@@ -369,7 +368,7 @@ void
StoreOnlyFeedView::heartBeatAttributes(SerialNum ) {}
void
-StoreOnlyFeedView::updateAttributes(SerialNum, Lid, const DocumentUpdate & upd, bool,
+StoreOnlyFeedView::updateAttributes(SerialNum, Lid, const DocumentUpdate & upd,
OnOperationDoneType, IFieldUpdateCallback & onUpdate)
{
for (const auto & fieldUpdate : upd.getUpdates()) {
@@ -378,12 +377,12 @@ StoreOnlyFeedView::updateAttributes(SerialNum, Lid, const DocumentUpdate & upd,
}
void
-StoreOnlyFeedView::updateAttributes(SerialNum, Lid, FutureDoc, bool, OnOperationDoneType)
+StoreOnlyFeedView::updateAttributes(SerialNum, Lid, FutureDoc, OnOperationDoneType)
{
}
void
-StoreOnlyFeedView::updateIndexedFields(SerialNum, Lid, FutureDoc, bool, OnOperationDoneType)
+StoreOnlyFeedView::updateIndexedFields(SerialNum, Lid, FutureDoc, OnOperationDoneType)
{
}
@@ -495,10 +494,9 @@ StoreOnlyFeedView::internalUpdate(FeedToken token, const UpdateOperation &updOp)
auto uncommitted = get_pending_lid_token(updOp);
considerEarlyAck(token);
- bool immediateCommit = needImmediateCommit();
auto onWriteDone = createUpdateDoneContext(std::move(token), std::move(uncommitted), updOp.getUpdate());
UpdateScope updateScope(*_schema, upd);
- updateAttributes(serialNum, lid, upd, immediateCommit, onWriteDone, updateScope);
+ updateAttributes(serialNum, lid, upd, onWriteDone, updateScope);
if (updateScope.hasIndexOrNonAttributeFields()) {
PromisedDoc promisedDoc;
@@ -506,7 +504,7 @@ StoreOnlyFeedView::internalUpdate(FeedToken token, const UpdateOperation &updOp)
onWriteDone->setDocument(futureDoc);
_pendingLidsForDocStore.waitComplete(lid);
if (updateScope._indexedFields) {
- updateIndexedFields(serialNum, lid, futureDoc, immediateCommit, onWriteDone);
+ updateIndexedFields(serialNum, lid, futureDoc, onWriteDone);
}
PromisedStream promisedStream;
FutureStream futureStream = promisedStream.get_future();
@@ -522,7 +520,7 @@ StoreOnlyFeedView::internalUpdate(FeedToken token, const UpdateOperation &updOp)
makeUpdatedDocument(serialNum, lid, *upd, onWriteDone,
std::move(promisedDoc), std::move(promisedStream));
}));
- updateAttributes(serialNum, lid, std::move(futureDoc), immediateCommit, onWriteDone);
+ updateAttributes(serialNum, lid, std::move(futureDoc), onWriteDone);
}
}
@@ -576,10 +574,10 @@ StoreOnlyFeedView::lookupDocId(const DocumentId &docId, Lid &lid) const
}
void
-StoreOnlyFeedView::removeAttributes(SerialNum, Lid, bool, OnRemoveDoneType) {}
+StoreOnlyFeedView::removeAttributes(SerialNum, Lid, OnRemoveDoneType) {}
void
-StoreOnlyFeedView::removeIndexedFields(SerialNum, Lid, bool, OnRemoveDoneType) {}
+StoreOnlyFeedView::removeIndexedFields(SerialNum, Lid, OnRemoveDoneType) {}
void
StoreOnlyFeedView::prepareRemove(RemoveOperation &rmOp)
@@ -666,9 +664,8 @@ StoreOnlyFeedView::internalRemove(FeedToken token, IPendingLidTracker::Token unc
std::move(pendingNotifyRemoveDone), (explicitReuseLid ? lid : 0u),
std::move(moveDoneCtx));
removeSummary(serialNum, lid, onWriteDone);
- bool immediateCommit = needImmediateCommit();
- removeAttributes(serialNum, lid, immediateCommit, onWriteDone);
- removeIndexedFields(serialNum, lid, immediateCommit, onWriteDone);
+ removeAttributes(serialNum, lid, onWriteDone);
+ removeIndexedFields(serialNum, lid, onWriteDone);
}
PendingNotifyRemoveDone
@@ -699,14 +696,13 @@ StoreOnlyFeedView::adjustMetaStore(const DocumentOperation &op, const GlobalId &
}
void
-StoreOnlyFeedView::removeAttributes(SerialNum, const LidVector &, bool , OnWriteDoneType ) {}
+StoreOnlyFeedView::removeAttributes(SerialNum, const LidVector &, OnWriteDoneType ) {}
void
-StoreOnlyFeedView::removeIndexedFields(SerialNum , const LidVector &, bool , OnWriteDoneType ) {}
+StoreOnlyFeedView::removeIndexedFields(SerialNum , const LidVector &, OnWriteDoneType ) {}
size_t
-StoreOnlyFeedView::removeDocuments(const RemoveDocumentsOperation &op, bool remove_index_and_attributes,
- bool immediateCommit)
+StoreOnlyFeedView::removeDocuments(const RemoveDocumentsOperation &op, bool remove_index_and_attributes)
{
const SerialNum serialNum = op.getSerialNum();
const LidVectorContext::SP &ctx = op.getLidsToRemove(_params._subDbId);
@@ -744,8 +740,8 @@ StoreOnlyFeedView::removeDocuments(const RemoveDocumentsOperation &op, bool remo
onWriteDone = std::make_shared<RemoveBatchDoneContext>(_writeService.master(), std::move(removeBatchDoneTask),
_gidToLidChangeHandler, std::move(gidsToRemove), serialNum);
if (remove_index_and_attributes) {
- removeIndexedFields(serialNum, lidsToRemove, immediateCommit, onWriteDone);
- removeAttributes(serialNum, lidsToRemove, immediateCommit, onWriteDone);
+ removeIndexedFields(serialNum, lidsToRemove, onWriteDone);
+ removeAttributes(serialNum, lidsToRemove, onWriteDone);
}
if (useDocumentStore(serialNum + 1)) {
for (const auto &lid : lidsToRemove) {
@@ -779,8 +775,7 @@ StoreOnlyFeedView::handleDeleteBucket(const DeleteBucketOperation &delOp)
void
StoreOnlyFeedView::internalDeleteBucket(const DeleteBucketOperation &delOp)
{
- bool immediateCommit = needImmediateCommit();
- size_t rm_count = removeDocuments(delOp, true, immediateCommit);
+ size_t rm_count = removeDocuments(delOp, true);
LOG(debug, "internalDeleteBucket(): docType(%s), bucket(%s), lidsToRemove(%zu)",
_params._docTypeName.toString().c_str(), delOp.getBucketId().toString().c_str(), rm_count);
}
@@ -818,15 +813,14 @@ StoreOnlyFeedView::handleMove(const MoveOperation &moveOp, IDestructorCallback::
PendingNotifyRemoveDone pendingNotifyRemoveDone = adjustMetaStore(moveOp, docId.getGlobalId(), docId);
bool docAlreadyExists = moveOp.getValidPrevDbdId(_params._subDbId);
if (moveOp.getValidDbdId(_params._subDbId)) {
- bool immediateCommit = needImmediateCommit();
const document::GlobalId &gid = docId.getGlobalId();
std::shared_ptr<PutDoneContext> onWriteDone =
createPutDoneContext(FeedToken(), _pendingLidsForCommit->produce(moveOp.getLid()),
_gidToLidChangeHandler, doc, gid, moveOp.getLid(), serialNum,
moveOp.changedDbdId() && useDocumentMetaStore(serialNum), doneCtx);
putSummary(serialNum, moveOp.getLid(), doc, onWriteDone);
- putAttributes(serialNum, moveOp.getLid(), *doc, immediateCommit, onWriteDone);
- putIndexedFields(serialNum, moveOp.getLid(), doc, immediateCommit, onWriteDone);
+ putAttributes(serialNum, moveOp.getLid(), *doc, onWriteDone);
+ putIndexedFields(serialNum, moveOp.getLid(), doc, onWriteDone);
}
if (docAlreadyExists && moveOp.changedDbdId()) {
internalRemove(FeedToken(), _pendingLidsForCommit->produce(moveOp.getPrevLid()), serialNum, std::move(pendingNotifyRemoveDone), moveOp.getPrevLid(), doneCtx);
@@ -853,7 +847,7 @@ handlePruneRemovedDocuments(const PruneRemovedDocumentsOperation &pruneOp)
{
assert(_params._subDbType == SubDbType::REMOVED);
assert(pruneOp.getSubDbId() == _params._subDbId);
- uint32_t rm_count = removeDocuments(pruneOp, false, false);
+ uint32_t rm_count = removeDocuments(pruneOp, false);
LOG(debug, "MinimalFeedView::handlePruneRemovedDocuments called, doctype(%s) %u lids pruned, limit %u",
_params._docTypeName.toString().c_str(), rm_count,
diff --git a/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.h b/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.h
index 167b246ec0b..da1459d521c 100644
--- a/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.h
+++ b/searchcore/src/vespa/searchcore/proton/server/storeonlyfeedview.h
@@ -181,8 +181,7 @@ private:
// Removes documents from meta store and document store.
// returns the number of documents removed.
- size_t removeDocuments(const RemoveDocumentsOperation &op, bool remove_index_and_attribute_fields,
- bool immediateCommit);
+ size_t removeDocuments(const RemoveDocumentsOperation &op, bool remove_index_and_attribute_fields);
void internalRemove(FeedToken token, IPendingLidTracker::Token uncommitted, SerialNum serialNum,
PendingNotifyRemoveDone &&pendingNotifyRemoveDone,
@@ -202,30 +201,20 @@ protected:
virtual void heartBeatAttributes(SerialNum serialNum);
private:
- virtual void putAttributes(SerialNum serialNum, Lid lid, const Document &doc,
- bool immediateCommit, OnPutDoneType onWriteDone);
-
- virtual void putIndexedFields(SerialNum serialNum, Lid lid, const DocumentSP &newDoc,
- bool immediateCommit, OnOperationDoneType onWriteDone);
+ virtual void putAttributes(SerialNum serialNum, Lid lid, const Document &doc, OnPutDoneType onWriteDone);
+ virtual void putIndexedFields(SerialNum serialNum, Lid lid, const DocumentSP &newDoc, OnOperationDoneType onWriteDone);
virtual void updateAttributes(SerialNum serialNum, Lid lid, const DocumentUpdate &upd,
- bool immediateCommit, OnOperationDoneType onWriteDone, IFieldUpdateCallback & onUpdate);
-
- virtual void updateAttributes(SerialNum serialNum, Lid lid, FutureDoc doc,
- bool immediateCommit, OnOperationDoneType onWriteDone);
+ OnOperationDoneType onWriteDone, IFieldUpdateCallback & onUpdate);
- virtual void updateIndexedFields(SerialNum serialNum, Lid lid, FutureDoc doc,
- bool immediateCommit, OnOperationDoneType onWriteDone);
-
- virtual void removeAttributes(SerialNum serialNum, Lid lid, bool immediateCommit, OnRemoveDoneType onWriteDone);
- virtual void removeIndexedFields(SerialNum serialNum, Lid lid, bool immediateCommit, OnRemoveDoneType onWriteDone);
+ virtual void updateAttributes(SerialNum serialNum, Lid lid, FutureDoc doc, OnOperationDoneType onWriteDone);
+ virtual void updateIndexedFields(SerialNum serialNum, Lid lid, FutureDoc doc, OnOperationDoneType onWriteDone);
+ virtual void removeAttributes(SerialNum serialNum, Lid lid, OnRemoveDoneType onWriteDone);
+ virtual void removeIndexedFields(SerialNum serialNum, Lid lid, OnRemoveDoneType onWriteDone);
protected:
- virtual void removeAttributes(SerialNum serialNum, const LidVector &lidsToRemove,
- bool immediateCommit, OnWriteDoneType onWriteDone);
-
- virtual void removeIndexedFields(SerialNum serialNum, const LidVector &lidsToRemove,
- bool immediateCommit, OnWriteDoneType onWriteDone);
+ virtual void removeAttributes(SerialNum serialNum, const LidVector &lidsToRemove, OnWriteDoneType onWriteDone);
+ virtual void removeIndexedFields(SerialNum serialNum, const LidVector &lidsToRemove, OnWriteDoneType onWriteDone);
virtual void internalForceCommit(SerialNum serialNum, OnForceCommitDoneType onCommitDone);
public:
StoreOnlyFeedView(const Context &ctx, const PersistentParams &params);
diff --git a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp
index 9142a03ab85..9de47b4a8a9 100644
--- a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp
+++ b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp
@@ -60,9 +60,17 @@ spi::LoadType defaultLoadType(0, "default");
struct TestFileStorComponents;
+document::Bucket
+make_bucket_for_doc(const document::DocumentId& docid)
+{
+ document::BucketIdFactory factory;
+ document::BucketId bucket_id(16, factory.getBucketId(docid).getRawId());
+ return makeDocumentBucket(bucket_id);
+}
+
}
-struct FileStorManagerTest : Test{
+struct FileStorTestBase : Test {
enum {LONG_WAITTIME=60};
unique_ptr<TestServiceLayerApp> _node;
std::unique_ptr<vdstestlib::DirConfig> config;
@@ -71,13 +79,13 @@ struct FileStorManagerTest : Test{
const uint32_t _waitTime;
const document::DocumentType* _testdoctype1;
- FileStorManagerTest() : _node(), _waitTime(LONG_WAITTIME) {}
+ FileStorTestBase() : _node(), _waitTime(LONG_WAITTIME) {}
+ ~FileStorTestBase();
void SetUp() override;
void TearDown() override;
- void createBucket(document::BucketId bid, uint16_t disk)
- {
+ void createBucket(document::BucketId bid, uint16_t disk) {
spi::Context context(defaultLoadType, spi::Priority(0), spi::Trace::TraceLevel(0));
assert(disk == 0u);
_node->getPersistenceProvider().createBucket(makeSpiBucket(bid), context);
@@ -88,11 +96,29 @@ struct FileStorManagerTest : Test{
entry.write();
}
- document::Document::UP createDocument(const std::string& content, const std::string& id)
- {
+ document::Document::UP createDocument(const std::string& content, const std::string& id) {
return _node->getTestDocMan().createDocument(content, id);
}
+ std::shared_ptr<api::PutCommand> make_put_command(StorageMessage::Priority pri = 20,
+ const std::string& docid = "id:foo:testdoctype1::bar",
+ Timestamp timestamp = 100) {
+ Document::SP doc(createDocument("my content", docid).release());
+ auto bucket = make_bucket_for_doc(doc->getId());
+ auto cmd = std::make_shared<api::PutCommand>(bucket, std::move(doc), timestamp);
+ cmd->setPriority(pri);
+ return cmd;
+ }
+
+ std::shared_ptr<api::GetCommand> make_get_command(StorageMessage::Priority pri,
+ const std::string& docid = "id:foo:testdoctype1::bar") {
+ document::DocumentId did(docid);
+ auto bucket = make_bucket_for_doc(did);
+ auto cmd = std::make_shared<api::GetCommand>(bucket, did, document::AllFields::NAME);
+ cmd->setPriority(pri);
+ return cmd;
+ }
+
bool ownsBucket(uint16_t distributorIndex,
const document::BucketId& bucket) const
{
@@ -163,10 +189,12 @@ struct FileStorManagerTest : Test{
const Metric& metric);
auto& thread_metrics_of(FileStorManager& manager) {
- return manager._metrics->disk->threads[0];
+ return manager.get_metrics().disk->threads[0];
}
};
+FileStorTestBase::~FileStorTestBase() = default;
+
std::string findFile(const std::string& path, const std::string& file) {
FastOS_DirectoryScan dirScan(path.c_str());
while (dirScan.ReadNext()) {
@@ -207,7 +235,7 @@ struct TestFileStorComponents {
DummyStorageLink top;
FileStorManager* manager;
- explicit TestFileStorComponents(FileStorManagerTest& test,
+ explicit TestFileStorComponents(FileStorTestBase& test,
bool use_small_config = false)
: manager(new FileStorManager((use_small_config ? test.smallConfig : test.config)->getConfigId(),
test._node->getPersistenceProvider(),
@@ -227,7 +255,7 @@ struct FileStorHandlerComponents {
FileStorMetrics metrics;
std::unique_ptr<FileStorHandler> filestorHandler;
- FileStorHandlerComponents(FileStorManagerTest& test, uint32_t threadsPerDisk = 1)
+ FileStorHandlerComponents(FileStorTestBase& test, uint32_t threadsPerDisk = 1)
: top(),
dummyManager(new DummyStorageLink),
messageSender(*dummyManager),
@@ -253,7 +281,7 @@ struct PersistenceHandlerComponents : public FileStorHandlerComponents {
BucketOwnershipNotifier bucketOwnershipNotifier;
std::unique_ptr<PersistenceHandler> persistenceHandler;
- PersistenceHandlerComponents(FileStorManagerTest& test)
+ PersistenceHandlerComponents(FileStorTestBase& test)
: FileStorHandlerComponents(test),
component(test._node->getComponentRegister(), "test"),
bucketOwnershipNotifier(component, messageSender),
@@ -277,17 +305,21 @@ PersistenceHandlerComponents::~PersistenceHandlerComponents() = default;
}
void
-FileStorManagerTest::SetUp()
+FileStorTestBase::SetUp()
{
setupDisks();
}
void
-FileStorManagerTest::TearDown()
+FileStorTestBase::TearDown()
{
_node.reset(0);
}
+struct FileStorManagerTest : public FileStorTestBase {
+
+};
+
TEST_F(FileStorManagerTest, header_only_put) {
TestFileStorComponents c(*this);
auto& top = c.top;
@@ -947,10 +979,10 @@ TEST_F(FileStorManagerTest, split_single_group) {
}
void
-FileStorManagerTest::putDoc(DummyStorageLink& top,
- FileStorHandler& filestorHandler,
- const document::BucketId& target,
- uint32_t docNum)
+FileStorTestBase::putDoc(DummyStorageLink& top,
+ FileStorHandler& filestorHandler,
+ const document::BucketId& target,
+ uint32_t docNum)
{
api::StorageMessageAddress address("storage", lib::NodeType::STORAGE, 3);
spi::Context context(defaultLoadType, spi::Priority(0),
@@ -1838,7 +1870,7 @@ TEST_F(FileStorManagerTest, create_bucket_sets_active_flag_in_database_and_reply
}
template <typename Metric>
-void FileStorManagerTest::assert_request_size_set(TestFileStorComponents& c, std::shared_ptr<api::StorageMessage> cmd, const Metric& metric) {
+void FileStorTestBase::assert_request_size_set(TestFileStorComponents& c, std::shared_ptr<api::StorageMessage> cmd, const Metric& metric) {
api::StorageMessageAddress address("storage", lib::NodeType::STORAGE, 3);
cmd->setApproxByteSize(54321);
cmd->setAddress(address);
@@ -1965,4 +1997,97 @@ TEST_F(FileStorManagerTest, bucket_db_is_populated_from_provider_when_initialize
EXPECT_EQ(reported_state->getState(), lib::State::UP);
}
+struct FileStorHandlerTest : public FileStorTestBase {
+ std::unique_ptr<FileStorHandlerComponents> c;
+ FileStorHandler* handler;
+ FileStorHandlerTest()
+ : FileStorTestBase(),
+ c(),
+ handler()
+ {}
+ void SetUp() override {
+ FileStorTestBase::SetUp();
+ c = std::make_unique<FileStorHandlerComponents>(*this);
+ handler = c->filestorHandler.get();
+ }
+ FileStorHandler::LockedMessage get_next_message() {
+ return handler->getNextMessage(0);
+ }
+};
+
+void
+expect_async_message(StorageMessage::Priority exp_pri,
+ const FileStorHandler::ScheduleAsyncResult& result)
+{
+ EXPECT_TRUE(result.was_scheduled());
+ ASSERT_TRUE(result.has_async_message());
+ EXPECT_EQ(exp_pri, result.async_message().second->getPriority());
+}
+
+void
+expect_empty_async_message(const FileStorHandler::ScheduleAsyncResult& result)
+{
+ EXPECT_TRUE(result.was_scheduled());
+ EXPECT_FALSE(result.has_async_message());
+}
+
+TEST_F(FileStorHandlerTest, message_not_scheduled_if_handler_is_closed)
+{
+ handler->setDiskState(FileStorHandler::DiskState::CLOSED);
+ auto result = handler->schedule_and_get_next_async_message(make_put_command());
+ EXPECT_FALSE(result.was_scheduled());
+}
+
+TEST_F(FileStorHandlerTest, no_async_message_returned_if_handler_is_paused)
+{
+ auto guard = handler->pause();
+ auto result = handler->schedule_and_get_next_async_message(make_put_command());
+ expect_empty_async_message(result);
+}
+
+TEST_F(FileStorHandlerTest, async_message_with_lowest_pri_returned_on_schedule)
+{
+ handler->schedule(make_put_command(20));
+ handler->schedule(make_put_command(40));
+ {
+ auto result = handler->schedule_and_get_next_async_message(make_put_command(30));
+ expect_async_message(20, result);
+ }
+ EXPECT_EQ(30, get_next_message().second->getPriority());
+ EXPECT_EQ(40, get_next_message().second->getPriority());
+}
+
+TEST_F(FileStorHandlerTest, no_async_message_returned_if_lowest_pri_message_is_not_async)
+{
+ // GET is not an async message.
+ handler->schedule(make_get_command(20));
+
+ auto result = handler->schedule_and_get_next_async_message(make_put_command(30));
+ expect_empty_async_message(result);
+
+ EXPECT_EQ(20, get_next_message().second->getPriority());
+ EXPECT_EQ(30, get_next_message().second->getPriority());
+}
+
+TEST_F(FileStorHandlerTest, inhibited_operations_are_skipped)
+{
+ std::string docid_a = "id:foo:testdoctype1::a";
+ std::string docid_b = "id:foo:testdoctype1::b";
+ handler->schedule(make_put_command(20, docid_a));
+ {
+ auto locked_msg = get_next_message();
+ {
+ // Bucket for docid_a is locked and put command for same bucket is inhibited.
+ auto result = handler->schedule_and_get_next_async_message(make_put_command(30, docid_a));
+ expect_empty_async_message(result);
+ }
+ {
+ // Put command for another bucket is ok.
+ auto result = handler->schedule_and_get_next_async_message(make_put_command(40, docid_b));
+ expect_async_message(40, result);
+ }
+ }
+ EXPECT_EQ(30, get_next_message().second->getPriority());
+}
+
} // storage
diff --git a/storage/src/vespa/storage/bucketdb/btree_lockable_map.h b/storage/src/vespa/storage/bucketdb/btree_lockable_map.h
index ea3a7838d43..6e42a721732 100644
--- a/storage/src/vespa/storage/bucketdb/btree_lockable_map.h
+++ b/storage/src/vespa/storage/bucketdb/btree_lockable_map.h
@@ -37,7 +37,7 @@ public:
using BucketId = document::BucketId;
BTreeLockableMap();
- ~BTreeLockableMap();
+ ~BTreeLockableMap() override;
bool operator==(const BTreeLockableMap& other) const;
bool operator!=(const BTreeLockableMap& other) const {
diff --git a/storage/src/vespa/storage/persistence/asynchandler.cpp b/storage/src/vespa/storage/persistence/asynchandler.cpp
index 1d1f5caf673..5344553dd45 100644
--- a/storage/src/vespa/storage/persistence/asynchandler.cpp
+++ b/storage/src/vespa/storage/persistence/asynchandler.cpp
@@ -182,6 +182,19 @@ AsyncHandler::handleRemove(api::RemoveCommand& cmd, MessageTracker::UP trackerUP
}
bool
+AsyncHandler::is_async_message(api::MessageType::Id type_id) noexcept
+{
+ switch (type_id) {
+ case api::MessageType::PUT_ID:
+ case api::MessageType::UPDATE_ID:
+ case api::MessageType::REMOVE_ID:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool
AsyncHandler::tasConditionExists(const api::TestAndSetCommand & cmd) {
return cmd.getCondition().isPresent();
}
diff --git a/storage/src/vespa/storage/persistence/asynchandler.h b/storage/src/vespa/storage/persistence/asynchandler.h
index c25f2ea0be6..92bf72e7c51 100644
--- a/storage/src/vespa/storage/persistence/asynchandler.h
+++ b/storage/src/vespa/storage/persistence/asynchandler.h
@@ -25,6 +25,7 @@ public:
MessageTrackerUP handlePut(api::PutCommand& cmd, MessageTrackerUP tracker) const;
MessageTrackerUP handleRemove(api::RemoveCommand& cmd, MessageTrackerUP tracker) const;
MessageTrackerUP handleUpdate(api::UpdateCommand& cmd, MessageTrackerUP tracker) const;
+ static bool is_async_message(api::MessageType::Id type_id) noexcept;
private:
static bool tasConditionExists(const api::TestAndSetCommand & cmd);
bool tasConditionMatches(const api::TestAndSetCommand & cmd, MessageTracker & tracker,
diff --git a/storage/src/vespa/storage/persistence/filestorage/filestorhandler.h b/storage/src/vespa/storage/persistence/filestorage/filestorhandler.h
index 44e768c9db7..aafc87aa84f 100644
--- a/storage/src/vespa/storage/persistence/filestorage/filestorhandler.h
+++ b/storage/src/vespa/storage/persistence/filestorage/filestorhandler.h
@@ -58,6 +58,30 @@ public:
};
using LockedMessage = std::pair<BucketLockInterface::SP, api::StorageMessage::SP>;
+ class ScheduleAsyncResult {
+ private:
+ bool _was_scheduled;
+ LockedMessage _async_message;
+
+ public:
+ ScheduleAsyncResult() : _was_scheduled(false), _async_message() {}
+ explicit ScheduleAsyncResult(LockedMessage&& async_message_in)
+ : _was_scheduled(true),
+ _async_message(std::move(async_message_in))
+ {}
+ bool was_scheduled() const {
+ return _was_scheduled;
+ }
+ bool has_async_message() const {
+ return _async_message.first.get() != nullptr;
+ }
+ const LockedMessage& async_message() const {
+ return _async_message;
+ }
+ LockedMessage&& release_async_message() {
+ return std::move(_async_message);
+ }
+ };
enum DiskState {
AVAILABLE,
@@ -104,6 +128,11 @@ public:
virtual bool schedule(const std::shared_ptr<api::StorageMessage>&) = 0;
/**
+ * Schedule the given message to be processed and return the next async message to process (if any).
+ */
+ virtual ScheduleAsyncResult schedule_and_get_next_async_message(const std::shared_ptr<api::StorageMessage>& msg) = 0;
+
+ /**
* Used by file stor threads to get their next message to process.
*
* @param stripe The stripe to get messages for
diff --git a/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.cpp b/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.cpp
index 0c34a421c06..14074b65c5c 100644
--- a/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.cpp
+++ b/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.cpp
@@ -10,6 +10,7 @@
#include <vespa/storage/common/statusmessages.h>
#include <vespa/storage/common/bucketoperationlogger.h>
#include <vespa/storage/common/messagebucket.h>
+#include <vespa/storage/persistence/asynchandler.h>
#include <vespa/storage/persistence/messages.h>
#include <vespa/storageapi/message/stat.h>
#include <vespa/vespalib/stllike/hash_map.hpp>
@@ -258,6 +259,16 @@ FileStorHandlerImpl::schedule(const std::shared_ptr<api::StorageMessage>& msg)
return false;
}
+FileStorHandler::ScheduleAsyncResult
+FileStorHandlerImpl::schedule_and_get_next_async_message(const std::shared_ptr<api::StorageMessage>& msg)
+{
+ if (getState() == FileStorHandler::AVAILABLE) {
+ document::Bucket bucket = getStorageMessageBucket(*msg);
+ return ScheduleAsyncResult(stripe(bucket).schedule_and_get_next_async_message(MessageEntry(msg, bucket)));
+ }
+ return {};
+}
+
bool
FileStorHandlerImpl::messageMayBeAborted(const api::StorageMessage& msg)
{
@@ -911,6 +922,24 @@ FileStorHandlerImpl::Stripe::getNextMessage(vespalib::duration timeout)
}
FileStorHandler::LockedMessage
+FileStorHandlerImpl::Stripe::get_next_async_message(monitor_guard& guard)
+{
+ if (_owner.isClosed() || _owner.isPaused()) {
+ return {};
+ }
+ PriorityIdx& idx(bmi::get<1>(*_queue));
+ PriorityIdx::iterator iter(idx.begin()), end(idx.end());
+
+ while ((iter != end) && operationIsInhibited(guard, iter->_bucket, *iter->_command)) {
+ ++iter;
+ }
+ if ((iter != end) && AsyncHandler::is_async_message(iter->_command->getType().getId())) {
+ return getMessage(guard, idx, iter);
+ }
+ return {};
+}
+
+FileStorHandler::LockedMessage
FileStorHandlerImpl::Stripe::getMessage(monitor_guard & guard, PriorityIdx & idx, PriorityIdx::iterator iter) {
api::StorageMessage & m(*iter->_command);
@@ -989,6 +1018,14 @@ bool FileStorHandlerImpl::Stripe::schedule(MessageEntry messageEntry)
return true;
}
+FileStorHandler::LockedMessage
+FileStorHandlerImpl::Stripe::schedule_and_get_next_async_message(MessageEntry entry)
+{
+ std::unique_lock guard(*_lock);
+ _queue->emplace_back(std::move(entry));
+ return get_next_async_message(guard);
+}
+
void
FileStorHandlerImpl::Stripe::flush()
{
diff --git a/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.h b/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.h
index 6aac8b0474b..549de164229 100644
--- a/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.h
+++ b/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.h
@@ -101,6 +101,7 @@ public:
~Stripe();
void flush();
bool schedule(MessageEntry messageEntry);
+ FileStorHandler::LockedMessage schedule_and_get_next_async_message(MessageEntry entry);
void waitUntilNoLocks() const;
void abort(std::vector<std::shared_ptr<api::StorageReply>> & aborted, const AbortBucketOperationsCommand& cmd);
void waitInactive(const AbortBucketOperationsCommand& cmd) const;
@@ -137,6 +138,8 @@ public:
void setMetrics(FileStorStripeMetrics * metrics) { _metrics = metrics; }
private:
bool hasActive(monitor_guard & monitor, const AbortBucketOperationsCommand& cmd) const;
+ FileStorHandler::LockedMessage get_next_async_message(monitor_guard& guard);
+
// Precondition: the bucket used by `iter`s operation is not locked in a way that conflicts
// with its locking requirements.
FileStorHandler::LockedMessage getMessage(monitor_guard & guard, PriorityIdx & idx,
@@ -184,6 +187,7 @@ public:
DiskState getDiskState() const override;
void close() override;
bool schedule(const std::shared_ptr<api::StorageMessage>&) override;
+ ScheduleAsyncResult schedule_and_get_next_async_message(const std::shared_ptr<api::StorageMessage>& msg) override;
FileStorHandler::LockedMessage getNextMessage(uint32_t stripeId) override;
diff --git a/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp b/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp
index 188523af38d..2653391ecfa 100644
--- a/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp
+++ b/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp
@@ -48,6 +48,7 @@ FileStorManager(const config::ConfigUri & configUri, spi::PersistenceProvider& p
_configFetcher(configUri.getContext()),
_threadLockCheckInterval(60),
_failDiskOnError(false),
+ _use_async_message_handling_on_schedule(false),
_metrics(std::make_unique<FileStorMetrics>(_component.getLoadTypes()->getMetricLoadTypes())),
_closed(false),
_lock()
@@ -151,6 +152,7 @@ FileStorManager::configure(std::unique_ptr<vespa::config::content::StorFilestorC
_threadLockCheckInterval = config->diskOperationTimeout;
_failDiskOnError = (config->failDiskAfterErrorCount > 0);
+ _use_async_message_handling_on_schedule = config->useAsyncMessageHandlingOnSchedule;
if (!liveUpdate) {
_config = std::move(config);
@@ -258,10 +260,20 @@ FileStorManager::handlePersistenceMessage(const shared_ptr<api::StorageMessage>&
api::ReturnCode errorCode(api::ReturnCode::OK);
LOG(spam, "Received %s. Attempting to queue it.", msg->getType().getName().c_str());
- if (_filestorHandler->schedule(msg)) {
- LOG(spam, "Received persistence message %s. Queued it to disk",
- msg->getType().getName().c_str());
- return true;
+ if (_use_async_message_handling_on_schedule) {
+ auto result = _filestorHandler->schedule_and_get_next_async_message(msg);
+ if (result.was_scheduled()) {
+ if (result.has_async_message()) {
+ getThreadLocalHandler().processLockedMessage(result.release_async_message());
+ }
+ return true;
+ }
+ } else {
+ if (_filestorHandler->schedule(msg)) {
+ LOG(spam, "Received persistence message %s. Queued it to disk",
+ msg->getType().getName().c_str());
+ return true;
+ }
}
switch (_filestorHandler->getDiskState()) {
case FileStorHandler::DISABLED:
diff --git a/storage/src/vespa/storage/persistence/filestorage/filestormanager.h b/storage/src/vespa/storage/persistence/filestorage/filestormanager.h
index ee66bc7d77c..2953462dd1e 100644
--- a/storage/src/vespa/storage/persistence/filestorage/filestormanager.h
+++ b/storage/src/vespa/storage/persistence/filestorage/filestormanager.h
@@ -65,14 +65,13 @@ class FileStorManager : public StorageLinkQueued,
config::ConfigFetcher _configFetcher;
uint32_t _threadLockCheckInterval; // In seconds
bool _failDiskOnError;
+ bool _use_async_message_handling_on_schedule;
std::shared_ptr<FileStorMetrics> _metrics;
std::unique_ptr<FileStorHandler> _filestorHandler;
std::unique_ptr<vespalib::ISequencedTaskExecutor> _sequencedExecutor;
bool _closed;
std::mutex _lock;
- friend struct FileStorManagerTest;
-
public:
FileStorManager(const config::ConfigUri &, spi::PersistenceProvider&,
ServiceLayerComponentRegister&, DoneInitializeHandler&);
@@ -105,6 +104,8 @@ public:
// yet at that point in time.
void initialize_bucket_databases_from_provider();
+ const FileStorMetrics& get_metrics() const { return *_metrics; }
+
private:
void configure(std::unique_ptr<vespa::config::content::StorFilestorConfig> config) override;
PersistenceHandler & createRegisteredHandler(const ServiceLayerComponent & component);
diff --git a/storage/src/vespa/storage/persistence/mergehandler.cpp b/storage/src/vespa/storage/persistence/mergehandler.cpp
index 4fe7333fb5f..c7c681a838b 100644
--- a/storage/src/vespa/storage/persistence/mergehandler.cpp
+++ b/storage/src/vespa/storage/persistence/mergehandler.cpp
@@ -403,9 +403,9 @@ MergeHandler::fetchLocalData(
|| (entries.empty() && alreadyFilled == 0))
{
remainingSize -= entry->getSize();
+ entries.push_back(std::move(entry));
LOG(spam, "Added %s, remainingSize is %u",
entries.back()->toString().c_str(), remainingSize);
- entries.push_back(std::move(entry));
} else {
LOG(spam, "Adding %s would exceed chunk size limit of %u; "
"not filling up any more diffs for current round",
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/DefaultZmsClient.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/DefaultZmsClient.java
index eaf83238145..33cb6d7d5d4 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/DefaultZmsClient.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/DefaultZmsClient.java
@@ -9,8 +9,10 @@ import com.yahoo.vespa.athenz.api.OktaAccessToken;
import com.yahoo.vespa.athenz.api.OktaIdentityToken;
import com.yahoo.vespa.athenz.client.common.ClientBase;
import com.yahoo.vespa.athenz.client.zms.bindings.AccessResponseEntity;
+import com.yahoo.vespa.athenz.client.zms.bindings.AssertionEntity;
import com.yahoo.vespa.athenz.client.zms.bindings.DomainListResponseEntity;
import com.yahoo.vespa.athenz.client.zms.bindings.MembershipResponseEntity;
+import com.yahoo.vespa.athenz.client.zms.bindings.PolicyEntity;
import com.yahoo.vespa.athenz.client.zms.bindings.ProviderResourceGroupRolesRequestEntity;
import com.yahoo.vespa.athenz.client.zms.bindings.TenancyRequestEntity;
import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider;
@@ -23,6 +25,8 @@ import javax.net.ssl.SSLContext;
import java.net.URI;
import java.util.Collections;
import java.util.List;
+import java.util.Optional;
+import java.util.OptionalInt;
import java.util.Set;
import java.util.function.Supplier;
@@ -149,6 +153,47 @@ public class DefaultZmsClient extends ClientBase implements ZmsClient {
});
}
+ @Override
+ public void addPolicyRule(AthenzDomain athenzDomain, String athenzPolicy, String action, AthenzResourceName resourceName, AthenzRole athenzRole) {
+ URI uri = zmsUrl.resolve(String.format("domain/%s/policy/%s/assertion",
+ athenzDomain.getName(), athenzPolicy));
+ HttpUriRequest request = RequestBuilder.put()
+ .setUri(uri)
+ .setEntity(toJsonStringEntity(new AssertionEntity(athenzRole.toResourceNameString(), resourceName.toResourceNameString(), action)))
+ .build();
+ execute(request, response -> readEntity(response, Void.class));
+ }
+
+ @Override
+ public boolean deletePolicyRule(AthenzDomain athenzDomain, String athenzPolicy, String action, AthenzResourceName resourceName, AthenzRole athenzRole) {
+ URI uri = zmsUrl.resolve(String.format("domain/%s/policy/%s",
+ athenzDomain.getName(), athenzPolicy));
+ HttpUriRequest request = RequestBuilder.get()
+ .setUri(uri)
+ .build();
+ PolicyEntity policyEntity = execute(request, response -> readEntity(response, PolicyEntity.class));
+
+ OptionalInt assertionId = policyEntity.getAssertions().stream()
+ .filter(assertionEntity -> assertionEntity.getAction().equals(action) &&
+ assertionEntity.getResource().equals(resourceName.toResourceNameString()) &&
+ assertionEntity.getRole().equals(athenzRole.toResourceNameString()))
+ .mapToInt(AssertionEntity::getId).findFirst();
+
+ if (assertionId.isEmpty()) {
+ return false;
+ }
+
+ uri = zmsUrl.resolve(String.format("domain/%s/policy/%s/assertion/%d",
+ athenzDomain.getName(), athenzPolicy, assertionId.getAsInt()));
+
+ request = RequestBuilder.delete()
+ .setUri(uri)
+ .build();
+
+ execute(request, response -> readEntity(response, Void.class));
+ return true;
+ }
+
private static Header createCookieHeaderWithOktaTokens(OktaIdentityToken identityToken, OktaAccessToken accessToken) {
return new BasicHeader("Cookie", String.format("okta_at=%s; okta_it=%s", accessToken.token(), identityToken.token()));
}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/ZmsClient.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/ZmsClient.java
index 12762534bd4..c7f865a58bb 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/ZmsClient.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/ZmsClient.java
@@ -38,5 +38,9 @@ public interface ZmsClient extends AutoCloseable {
boolean hasAccess(AthenzResourceName resource, String action, AthenzIdentity identity);
+ void addPolicyRule(AthenzDomain athenzDomain, String athenzPolicy, String action, AthenzResourceName resourceName, AthenzRole athenzRole);
+
+ boolean deletePolicyRule(AthenzDomain athenzDomain, String athenzPolicy, String action, AthenzResourceName resourceName, AthenzRole athenzRole);
+
void close();
}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/bindings/AssertionEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/bindings/AssertionEntity.java
new file mode 100644
index 00000000000..824aa3b4606
--- /dev/null
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/bindings/AssertionEntity.java
@@ -0,0 +1,52 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.athenz.client.zms.bindings;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+/**
+ * @author olaa
+ */
+@JsonInclude(JsonInclude.Include.NON_NULL)
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class AssertionEntity {
+
+ private final String role;
+ private final String resource;
+ private final String action;
+ private final Integer id;
+
+
+ public AssertionEntity(String role, String resource, String action) {
+ this(role, resource, action, null);
+ }
+
+ public AssertionEntity(@JsonProperty("role") String role,
+ @JsonProperty("resource") String resource,
+ @JsonProperty("action") String action,
+ @JsonProperty("id") Integer id) {
+ this.role = role;
+ this.resource = resource;
+ this.action = action;
+ this.id = id;
+ }
+
+ public String getRole() {
+ return role;
+ }
+
+ public String getResource() {
+ return resource;
+ }
+
+ public String getAction() {
+ return action;
+ }
+
+ @JsonIgnore
+ public int getId() {
+ return id;
+ }
+}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/bindings/PolicyEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/bindings/PolicyEntity.java
new file mode 100644
index 00000000000..ebc0997cb09
--- /dev/null
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/zms/bindings/PolicyEntity.java
@@ -0,0 +1,33 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.athenz.client.zms.bindings;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import java.util.List;
+
+/**
+ * @author olaa
+ */
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class PolicyEntity {
+
+ @JsonInclude(JsonInclude.Include.NON_EMPTY)
+ private final List<AssertionEntity> assertions;
+ private final String name;
+
+ public PolicyEntity(@JsonProperty("name") String name,
+ @JsonProperty("assertions") List<AssertionEntity> assertions) {
+ this.name = name;
+ this.assertions = assertions;
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ public List<AssertionEntity> getAssertions() {
+ return assertions;
+ }
+}
diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java b/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java
index 13ca774dc33..4fdac7b584a 100644
--- a/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java
+++ b/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java
@@ -244,13 +244,19 @@ public class DocumentV1ApiHandler extends AbstractRequestHandler {
@Override
public void destroy() {
executor.shutdown();
+ Instant doom = clock.instant().plus(Duration.ofSeconds(20));
+ while ( ! operations.isEmpty() && clock.instant().isBefore(doom))
+ dispatchEnqueued();
+
+ if ( ! operations.isEmpty())
+ log.log(WARNING, "Failed to empty request queue before shutdown timeout — " + operations.size() + " requests left");
+
+ asyncSession.destroy();
visits.values().forEach(VisitorSession::destroy);
+
try {
- if ( ! executor.awaitTermination(10, TimeUnit.SECONDS)) {
+ if ( ! executor.awaitTermination(Duration.between(clock.instant(), doom).toMillis(), TimeUnit.MILLISECONDS))
executor.shutdownNow();
- if ( ! executor.awaitTermination(10, TimeUnit.SECONDS))
- log.log(WARNING, "Failed shutting down /document/v1 executor within 20 seconds");
- }
}
catch (InterruptedException e) {
log.log(WARNING, "Interrupted waiting for /document/v1 executor to shut down");
@@ -729,13 +735,12 @@ public class DocumentV1ApiHandler extends AbstractRequestHandler {
jsonResponse.commit(Response.Status.PRECONDITION_FAILED);
break;
case INSUFFICIENT_STORAGE:
- log.log(WARNING, "Insufficient storage left in cluster: " + response.getTextMessage());
jsonResponse.commit(Response.Status.INSUFFICIENT_STORAGE);
break;
default:
log.log(WARNING, "Unexpected document API operation outcome '" + response.outcome() + "'");
case ERROR:
- log.log(WARNING, "Exception performing document operation: " + response.getTextMessage());
+ log.log(FINE, () -> "Exception performing document operation: " + response.getTextMessage());
jsonResponse.commit(Response.Status.INTERNAL_SERVER_ERROR);
}
}