summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java3
-rw-r--r--config-model/src/main/protobuf/onnx.proto517
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java10
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java2
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java45
-rw-r--r--container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java34
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/application/v4/model/InstanceInformation.java6
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java2
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java7
-rw-r--r--model-integration/src/main/protobuf/onnx.proto517
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java4
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/LockedNodeList.java9
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeMutex.java1
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java2
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java10
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/FailedExpirer.java77
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/InactiveExpirer.java8
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ReservationExpirer.java4
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java6
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java433
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java14
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java2
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java2
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/RealDataScenarioTest.java2
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityCheckerTester.java7
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java3
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainerTest.java5
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java6
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ProvisionedExpirerTest.java3
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/SpareCapacityMaintainerTest.java6
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicAllocationTest.java6
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java2
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h11
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp13
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp1
-rw-r--r--searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp19
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp1
-rw-r--r--storage/src/vespa/storage/bucketdb/btree_bucket_database.cpp1
-rw-r--r--vespalib/src/tests/datastore/array_store/array_store_test.cpp1
-rw-r--r--vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp74
-rw-r--r--vespalib/src/vespa/vespalib/datastore/array_store.h2
-rw-r--r--vespalib/src/vespa/vespalib/datastore/array_store.hpp14
-rw-r--r--vespalib/src/vespa/vespalib/datastore/array_store_config.cpp20
-rw-r--r--vespalib/src/vespa/vespalib/datastore/array_store_config.h4
48 files changed, 1494 insertions, 439 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
index 1984ceadac6..8edd446b209 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
@@ -274,7 +274,8 @@ public class OnnxModelInfo {
static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException {
g.writeStartObject();
g.writeStringField("name", valueInfo.getName());
- g.writeStringField("type", onnxValueTypeToString(valueInfo.getType().getTensorType().getElemType()));
+ var elemType = Onnx.TensorProto.DataType.forNumber(valueInfo.getType().getTensorType().getElemType());
+ g.writeStringField("type", onnxValueTypeToString(elemType));
g.writeArrayFieldStart("dim");
for (Onnx.TensorShapeProto.Dimension dim : valueInfo.getType().getTensorType().getShape().getDimList()) {
g.writeStartObject();
diff --git a/config-model/src/main/protobuf/onnx.proto b/config-model/src/main/protobuf/onnx.proto
index dc6542867e0..1d265ae9f28 100644
--- a/config-model/src/main/protobuf/onnx.proto
+++ b/config-model/src/main/protobuf/onnx.proto
@@ -3,8 +3,8 @@
//
-// Copyright (c) Facebook Inc. and Microsoft Corporation.
-// Licensed under the MIT license.
+// SPDX-License-Identifier: Apache-2.0
+
syntax = "proto2";
@@ -20,23 +20,16 @@ package onnx;
//
// This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX
-// Intermediate Representation, or 'IR' for short.
+// Intermediate Representation, or 'IR' for short.
//
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md.
// Notes
//
-// Release
-//
-// We are still in the very early stage of defining ONNX. The current
-// version of ONNX is a starting point. While we are actively working
-// towards a complete spec, we would like to get the community involved
-// by sharing our working version of ONNX.
-//
// Protobuf compatibility
-//
-// To simplify framework compatibility, ONNX is defined using the subset of protobuf
+//
+// To simplify framework compatibility, ONNX is defined using the subset of protobuf
// that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in one of the two versions.
//
@@ -60,22 +53,60 @@ enum Version {
_START_VERSION = 0;
// The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version
- // control. We should use version as
- // xx(major) - xx(minor) - xxxx(bugfix)
- // and we are starting with 0x00000001 (0.0.1), which was the
- // version we published on Oct 10, 2017.
- IR_VERSION_2017_10_10 = 0x00000001;
+ // control.
+ // For the IR, we are using simple numbers starting with 0x00000001,
+ // which was the version we published on Oct 10, 2017.
+ IR_VERSION_2017_10_10 = 0x0000000000000001;
- // IR_VERSION 0.0.2 published on Oct 30, 2017
+ // IR_VERSION 2 published on Oct 30, 2017
// - Added type discriminator to AttributeProto to support proto3 users
- IR_VERSION_2017_10_30 = 0x00000002;
+ IR_VERSION_2017_10_30 = 0x0000000000000002;
- // IR VERSION 0.0.3 published on Nov 3, 2017
+ // IR VERSION 3 published on Nov 3, 2017
// - For operator versioning:
// - Added new message OperatorSetIdProto
// - Added opset_import in ModelProto
// - For vendor extensions, added domain in NodeProto
- IR_VERSION = 0x00000003;
+ IR_VERSION_2017_11_3 = 0x0000000000000003;
+
+ // IR VERSION 4 published on Jan 22, 2019
+ // - Relax constraint that initializers should be a subset of graph inputs
+ // - Add type BFLOAT16
+ IR_VERSION_2019_1_22 = 0x0000000000000004;
+
+ // IR VERSION 5 published on March 18, 2019
+ // - Add message TensorAnnotation.
+ // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
+ IR_VERSION_2019_3_18 = 0x0000000000000005;
+
+ // IR VERSION 6 published on Sep 19, 2019
+ // - Add support for sparse tensor constants stored in model.
+ // - Add message SparseTensorProto
+ // - Add sparse initializers
+ IR_VERSION_2019_9_19 = 0x0000000000000006;
+
+ // IR VERSION 7 published on May 8, 2020
+  // - Add support to allow function body graph to rely on multiple external operator sets.
+ // - Add a list to promote inference graph's initializers to global and
+ // mutable variables. Global variables are visible in all graphs of the
+ // stored models.
+ // - Add message TrainingInfoProto to store initialization
+ // method and training algorithm. The execution of TrainingInfoProto
+ // can modify the values of mutable variables.
+ // - Implicitly add inference graph into each TrainingInfoProto's algorithm.
+ IR_VERSION_2020_5_8 = 0x0000000000000007;
+
+ // IR VERSION 8 published on July 30, 2021
+ // Introduce TypeProto.SparseTensor
+ // Introduce TypeProto.Optional
+ // Added a list of FunctionProtos local to the model
+ // Deprecated since_version and operator status from FunctionProto
+ IR_VERSION_2021_7_30 = 0x0000000000000008;
+
+ // IR VERSION 9 published on May 5, 2023
+ // Added AttributeProto to FunctionProto so that default attribute values can be set.
+ // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
+ IR_VERSION = 0x0000000000000009;
}
// Attributes
@@ -95,17 +126,21 @@ message AttributeProto {
STRING = 3;
TENSOR = 4;
GRAPH = 5;
+ SPARSE_TENSOR = 11;
+ TYPE_PROTO = 13;
FLOATS = 6;
INTS = 7;
STRINGS = 8;
TENSORS = 9;
GRAPHS = 10;
+ SPARSE_TENSORS = 12;
+ TYPE_PROTOS = 14;
}
// The name field MUST be present for this version of the IR.
optional string name = 1; // namespace Attribute
-
+
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope.
@@ -117,10 +152,10 @@ message AttributeProto {
// The type field MUST be present for this version of the IR.
// For 0.0.1 versions of the IR, this field was not defined, and
- // implementations needed to use has_field hueristics to determine
+ // implementations needed to use has_field heuristics to determine
// which value field was in use. For IR_VERSION 0.0.2 or later, this
// field MUST be set and match the f|i|s|t|... field in use. This
- // change was made to accomodate proto3 implementations.
+ // change was made to accommodate proto3 implementations.
optional AttributeType type = 20; // discriminator that indicates which field below is in use
// Exactly ONE of the following fields must be present for this version of the IR
@@ -129,14 +164,18 @@ message AttributeProto {
optional bytes s = 4; // UTF-8 string
optional TensorProto t = 5; // tensor value
optional GraphProto g = 6; // graph
+ optional SparseTensorProto sparse_tensor = 22; // sparse tensor value
// Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph
+ optional TypeProto tp = 14; // type proto
repeated float floats = 7; // list of floats
repeated int64 ints = 8; // list of ints
repeated bytes strings = 9; // list of UTF-8 strings
repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph
+ repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
+ repeated TypeProto type_protos = 15;// list of type protos
}
// Defines information on value, including the name, the type, and
@@ -144,7 +183,8 @@ message AttributeProto {
message ValueInfoProto {
// This field MUST be present in this version of the IR.
optional string name = 1; // namespace Value
- // This field MUST be present in this version of the IR.
+ // This field MUST be present in this version of the IR for
+ // inputs and outputs of the top-level graph.
optional TypeProto type = 2;
// A human-readable documentation for this value. Markdown is allowed.
optional string doc_string = 3;
@@ -155,7 +195,7 @@ message ValueInfoProto {
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
-// For example, it can be a node of type "Conv" that takes in an image, a filter
+// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
repeated string input = 1; // namespace Value
@@ -177,12 +217,130 @@ message NodeProto {
optional string doc_string = 6;
}
+// Training information
+// TrainingInfoProto stores information for training a model.
+// In particular, this defines two functionalities: an initialization-step
+// and a training-algorithm-step. Initialization resets the model
+// back to its original state as if no training has been performed.
+// Training algorithm improves the model based on input data.
+//
+// The semantics of the initialization-step is that the initializers
+// in ModelProto.graph and in TrainingInfoProto.algorithm are first
+// initialized as specified by the initializers in the graph, and then
+// updated by the "initialization_binding" in every instance in
+// ModelProto.training_info.
+//
+// The field "algorithm" defines a computation graph which represents a
+// training algorithm's step. After the execution of a
+// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
+// may be immediately updated. If the targeted training algorithm contains
+// consecutive update steps (such as block coordinate descent methods),
+// the user needs to create a TrainingInfoProto for each step.
+message TrainingInfoProto {
+ // This field describes a graph to compute the initial tensors
+ // upon starting the training process. Initialization graph has no input
+ // and can have multiple outputs. Usually, trainable tensors in neural
+ // networks are randomly initialized. To achieve that, for each tensor,
+ // the user can put a random number operator such as RandomNormal or
+ // RandomUniform in TrainingInfoProto.initialization.node and assign its
+ // random output to the specific tensor using "initialization_binding".
+ // This graph can also set the initializers in "algorithm" in the same
+ // TrainingInfoProto; a use case is resetting the number of training
+ // iteration to zero.
+ //
+ // By default, this field is an empty graph and its evaluation does not
+ // produce any output. Thus, no initializer would be changed by default.
+ optional GraphProto initialization = 1;
+
+ // This field represents a training algorithm step. Given required inputs,
+ // it computes outputs to update initializers in its own or inference graph's
+ // initializer lists. In general, this field contains loss node, gradient node,
+ // optimizer node, increment of iteration count.
+ //
+ // An execution of the training algorithm step is performed by executing the
+ // graph obtained by combining the inference graph (namely "ModelProto.graph")
+ // and the "algorithm" graph. That is, the actual
+ // input/initializer/output/node/value_info/sparse_initializer list of
+ // the training graph is the concatenation of
+ // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
+ // and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
+ // in that order. This combined graph must satisfy the normal ONNX conditions.
+ // Now, let's provide a visualization of graph combination for clarity.
+ // Let the inference graph (i.e., "ModelProto.graph") be
+ // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
+ // and the "algorithm" graph be
+ // tensor_d -> Add -> tensor_e
+ // The combination process results
+ // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
+ //
+ // Notice that an input of a node in the "algorithm" graph may reference the
+ // output of a node in the inference graph (but not the other way round). Also, inference
+ // node cannot reference inputs of "algorithm". With these restrictions, inference graph
+ // can always be run independently without training information.
+ //
+ // By default, this field is an empty graph and its evaluation does not
+ // produce any output. Evaluating the default training step never
+  // updates any initializers.
+ optional GraphProto algorithm = 2;
+
+ // This field specifies the bindings from the outputs of "initialization" to
+ // some initializers in "ModelProto.graph.initializer" and
+ // the "algorithm.initializer" in the same TrainingInfoProto.
+ // See "update_binding" below for details.
+ //
+ // By default, this field is empty and no initializer would be changed
+ // by the execution of "initialization".
+ repeated StringStringEntryProto initialization_binding = 3;
+
+ // Gradient-based training is usually an iterative procedure. In one gradient
+ // descent iteration, we apply
+ //
+ // x = x - r * g
+ //
+ // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
+ // gradient of "x" with respect to a chosen loss. To avoid adding assignments
+ // into the training graph, we split the update equation into
+ //
+ // y = x - r * g
+ // x = y
+ //
+ // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
+ // tell that "y" should be assigned to "x", the field "update_binding" may
+ // contain a key-value pair of strings, "x" (key of StringStringEntryProto)
+ // and "y" (value of StringStringEntryProto).
+ // For a neural network with multiple trainable (mutable) tensors, there can
+ // be multiple key-value pairs in "update_binding".
+ //
+  // The initializers appearing as keys in "update_binding" are considered
+ // mutable variables. This implies some behaviors
+ // as described below.
+ //
+ // 1. We have only unique keys in all "update_binding"s so that two
+ // variables may not have the same name. This ensures that one
+ // variable is assigned up to once.
+ // 2. The keys must appear in names of "ModelProto.graph.initializer" or
+ // "TrainingInfoProto.algorithm.initializer".
+ // 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
+ // 4. Mutable variables are initialized to the value specified by the
+ // corresponding initializer, and then potentially updated by
+ // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
+ //
+ // This field usually contains names of trainable tensors
+ // (in ModelProto.graph), optimizer states such as momentums in advanced
+ // stochastic gradient methods (in TrainingInfoProto.graph),
+ // and number of training iterations (in TrainingInfoProto.graph).
+ //
+ // By default, this field is empty and no initializer would be changed
+ // by the execution of "algorithm".
+ repeated StringStringEntryProto update_binding = 4;
+}
+
// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
-// The semantics of the model are described by the associated GraphProto.
+// The semantics of the model are described by the associated GraphProto's.
message ModelProto {
// The version of the IR this model targets. See Version enum above.
// This field MUST be present.
@@ -227,18 +385,58 @@ message ModelProto {
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14;
+
+ // Training-specific information. Sequentially executing all stored
+ // `TrainingInfoProto.algorithm`s and assigning their outputs following
+ // the corresponding `TrainingInfoProto.update_binding`s is one training
+ // iteration. Similarly, to initialize the model
+ // (as if training hasn't happened), the user should sequentially execute
+ // all stored `TrainingInfoProto.initialization`s and assigns their outputs
+ // using `TrainingInfoProto.initialization_binding`s.
+ //
+ // If this field is empty, the training behavior of the model is undefined.
+ repeated TrainingInfoProto training_info = 20;
+
+ // A list of function protos local to the model.
+ //
+ // Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
+ // In case of any conflicts the behavior (whether the model local functions are given higher priority,
+  // or standard operator sets are given higher priority or this is treated as error) is defined by
+ // the runtimes.
+ //
+ // The operator sets imported by FunctionProto should be compatible with the ones
+ // imported by ModelProto and other model local FunctionProtos.
+ // Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
+ // or by 2 FunctionProtos then versions for the operator set may be different but,
+ // the operator schema returned for op_type, domain, version combination
+ // for both the versions should be same for every node in the function body.
+ //
+ // One FunctionProto can reference other FunctionProto in the model, however, recursive reference
+ // is not allowed.
+ repeated FunctionProto functions = 25;
};
// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
optional string key = 1;
- optional string value= 2;
+ optional string value = 2;
};
+message TensorAnnotation {
+ optional string tensor_name = 1;
+ // <key, value> pairs to annotate tensor specified by <tensor_name> above.
+ // The keys used in the mapping below must be pre-defined in ONNX spec.
+ // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
+ // quantization parameter keys.
+ repeated StringStringEntryProto quant_parameter_tensor_names = 2;
+}
+
+
+
// Graphs
//
-// A graph defines the computational logic of a model and is comprised of a parameterized
+// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
@@ -250,10 +448,14 @@ message GraphProto {
optional string name = 2; // namespace Graph
// A list of named tensor values, used to specify constant inputs of the graph.
- // Each TensorProto entry must have a distinct name (within the list) that
- // also appears in the input list.
+ // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
+ // The name MUST be unique across both initializer and sparse_initializer,
+ // but the name MAY also appear in the input list.
repeated TensorProto initializer = 5;
+ // Initializers (see above) stored in sparse format.
+ repeated SparseTensorProto sparse_initializer = 15;
+
// A human-readable documentation for this graph. Markdown is allowed.
optional string doc_string = 10;
@@ -265,13 +467,14 @@ message GraphProto {
// must be distinct. It is optional for a value to appear in value_info list.
repeated ValueInfoProto value_info = 13;
- // DO NOT USE the following fields, they were deprecated from earlier versions.
- // repeated string input = 3;
- // repeated string output = 4;
- // optional int64 ir_version = 6;
- // optional int64 producer_version = 7;
- // optional string producer_tag = 8;
- // optional string domain = 9;
+ // This field carries information to indicate the mapping among a tensor and its
+ // quantization parameter tensors. For example:
+ // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
+ // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
+ repeated TensorAnnotation quantization_annotation = 14;
+
+ reserved 3, 4, 6 to 9;
+ reserved "ir_version", "producer_version", "producer_tag", "domain";
}
// Tensors
@@ -291,13 +494,32 @@ message TensorProto {
STRING = 8; // string
BOOL = 9; // bool
- // Advanced types
+ // IEEE754 half-precision floating-point format (16 bits wide).
+ // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10;
+
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
+
+ // Non-IEEE floating-point format based on IEEE754 single-precision
+ // floating-point number truncated to 16 bits.
+ // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
+ BFLOAT16 = 16;
+
+ // Non-IEEE floating-point format based on papers
+ // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
+ // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
+ // Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
+ // The computation usually happens inside a block quantize / dequantize
+ // fused by the runtime.
+ FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
+ FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
+ FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
+ FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
+
// Future extensions go here.
}
@@ -305,7 +527,8 @@ message TensorProto {
repeated int64 dims = 1;
// The data type of the tensor.
- optional DataType data_type = 2;
+ // This field MUST have a valid TensorProto.DataType value
+ optional int32 data_type = 2;
// For very large tensors, we may want to store them in chunks, in which
// case the following fields will specify the segment that is stored in
@@ -324,17 +547,17 @@ message TensorProto {
// For float and complex64 values
// Complex64 tensors are encoded as a single array of floats,
// with the real components appearing in odd numbered positions,
- // and the corresponding imaginary component apparing in the
+ // and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true];
- // For int32, uint8, int8, uint16, int16, bool, and float16 values
- // float16 values must be bit-wise converted to an uint16_t prior
+ // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
+ // float16 and float8 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer.
// When this field is present, the data_type field MUST be
- // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32
+ // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
repeated int32 int32_data = 5 [packed = true];
// For strings.
@@ -371,10 +594,32 @@ message TensorProto {
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
optional bytes raw_data = 9;
+ // Data can be stored inside the protobuf file using type-specific fields or raw_data.
+ // Alternatively, raw bytes data can be stored in an external file, using the external_data field.
+ // external_data stores key-value pairs describing data location. Recognized keys are:
+ // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
+ // protobuf model was stored
+ // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
+  //                       Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.
+ // - "length" (optional) - number of bytes containing data. Integer stored as string.
+ // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key.
+ repeated StringStringEntryProto external_data = 13;
+
+ // Location of the data for this tensor. MUST be one of:
+ // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
+ // - EXTERNAL - data stored in an external location as described by external_data field.
+ enum DataLocation {
+ DEFAULT = 0;
+ EXTERNAL = 1;
+ }
+
+ // If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
+ optional DataLocation data_location = 14;
+
// For double
- // Complex64 tensors are encoded as a single array of doubles,
+ // Complex128 tensors are encoded as a single array of doubles,
// with the real components appearing in odd numbered positions,
- // and the corresponding imaginary component apparing in the
+ // and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
@@ -386,6 +631,30 @@ message TensorProto {
repeated uint64 uint64_data = 11 [packed = true];
}
+// A serialized sparse-tensor value
+message SparseTensorProto {
+ // The sequence of non-default values are encoded as a tensor of shape [NNZ].
+ // The default-value is zero for numeric tensors, and empty-string for string tensors.
+ // values must have a non-empty name present which serves as a name for SparseTensorProto
+ // when used in sparse_initializer list.
+ optional TensorProto values = 1;
+
+ // The indices of the non-default values, which may be stored in one of two formats.
+ // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
+ // corresponding to the j-th index of the i-th value (in the values tensor).
+ // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
+ // must be the linearized-index of the i-th value (in the values tensor).
+ // The linearized-index can be converted into an index tuple (k_1,...,k_rank)
+ // using the shape provided below.
+ // The indices must appear in ascending order without duplication.
+ // In the first format, the ordering is lexicographic-ordering:
+ // e.g., index-value [1,4] must appear before [2,1]
+ optional TensorProto indices = 2;
+
+ // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
+ repeated int64 dims = 3;
+}
+
// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
@@ -398,36 +667,13 @@ message TensorShapeProto {
// Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor.
+ // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
+ // for pre-defined dimension denotations.
optional string denotation = 3;
};
repeated Dimension dim = 1;
}
-// A set of pre-defined constants to be used as values for
-// the standard denotation field in TensorShapeProto.Dimension
-// for semantic description of the tensor dimension.
-message DenotationConstProto {
- // Describe a batch number dimension.
- optional string DATA_BATCH = 1 [default = "DATA_BATCH"];
- // Describe a channel dimension.
- optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"];
- // Describe a time dimension.
- optional string DATA_TIME = 3 [default = "DATA_TIME"];
- // Describe a feature dimension. This is typically a feature
- // dimension in RNN and/or spatial dimension in CNN.
- optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"];
- // Describe a filter in-channel dimension. This is the dimension
- // that is identical (in size) to the channel dimension of the input
- // image feature maps.
- optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"];
- // Describe a filter out channel dimension. This is the dimension
- // that is identical (int size) to the channel dimension of the output
- // image feature maps.
- optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"];
- // Describe a filter spatial dimension.
- optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"];
-}
-
// Types
//
// The standard ONNX data types.
@@ -435,8 +681,43 @@ message TypeProto {
message Tensor {
// This field MUST NOT have the value of UNDEFINED
+ // This field MUST have a valid TensorProto.DataType value
+ // This field MUST be present for this version of the IR.
+ optional int32 elem_type = 1;
+ optional TensorShapeProto shape = 2;
+ }
+
+ // repeated T
+ message Sequence {
+ // The type and optional shape of each element of the sequence.
+ // This field MUST be present for this version of the IR.
+ optional TypeProto elem_type = 1;
+ };
+
+ // map<K,V>
+ message Map {
+ // This field MUST have a valid TensorProto.DataType value
+ // This field MUST be present for this version of the IR.
+ // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
+ optional int32 key_type = 1;
+ // This field MUST be present for this version of the IR.
+ optional TypeProto value_type = 2;
+ };
+
+ // wrapper for Tensor, Sequence, or Map
+ message Optional {
+ // The type and optional shape of the element wrapped.
+ // This field MUST be present for this version of the IR.
+ // Possible values correspond to OptionalProto.DataType enum
+ optional TypeProto elem_type = 1;
+ };
+
+
+ message SparseTensor {
+ // This field MUST NOT have the value of UNDEFINED
+ // This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
- optional TensorProto.DataType elem_type = 1;
+ optional int32 elem_type = 1;
optional TensorShapeProto shape = 2;
}
@@ -445,7 +726,31 @@ message TypeProto {
// The type of a tensor.
Tensor tensor_type = 1;
+ // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
+ // as input and output to graphs and nodes. These types are needed to naturally
+ // support classical ML operators. DNN operators SHOULD restrict their input
+ // and output types to tensors.
+
+ // The type of a sequence.
+ Sequence sequence_type = 4;
+
+ // The type of a map.
+ Map map_type = 5;
+
+ // The type of an optional.
+ Optional optional_type = 9;
+
+
+ // Type of the sparse tensor
+ SparseTensor sparse_tensor_type = 8;
+
}
+
+ // An optional denotation can be used to denote the whole
+ // type with a standard semantic description as to what is
+ // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
+ // for pre-defined type denotations.
+ optional string denotation = 6;
}
// Operator Sets
@@ -461,4 +766,70 @@ message OperatorSetIdProto {
// The version of the operator set being identified.
// This field MUST be present in this version of the IR.
optional int64 version = 2;
-} \ No newline at end of file
+}
+
+// Operator/function status.
+enum OperatorStatus {
+ EXPERIMENTAL = 0;
+ STABLE = 1;
+}
+
+message FunctionProto {
+ // The name of the function, similar usage of op_type in OperatorProto.
+ // Combined with FunctionProto.domain, this forms the unique identity of
+ // the FunctionProto.
+ optional string name = 1;
+
+ // Deprecated since IR Version 8
+ // optional int64 since_version = 2;
+ reserved 2;
+ reserved "since_version";
+
+ // Deprecated since IR Version 8
+ // optional OperatorStatus status = 3;
+ reserved 3;
+ reserved "status";
+
+ // The inputs and outputs of the function.
+ repeated string input = 4;
+ repeated string output = 5;
+
+ // The attribute parameters of the function.
+ // It is for function parameters without default values.
+ repeated string attribute = 6;
+
+ // The attribute protos of the function.
+ // It is for function attributes with default values.
+ // A function attribute shall be represented either as
+ // a string attribute or an AttributeProto, not both.
+ repeated AttributeProto attribute_proto = 11;
+
+ // The nodes in the function.
+ repeated NodeProto node = 7;
+ // A human-readable documentation for this function. Markdown is allowed.
+ optional string doc_string = 8;
+
+ // The OperatorSets this function body (graph) relies on.
+ //
+ // All nodes in the function body (graph) will bind against the operator
+ // with the same-domain/same-op_type operator with the HIGHEST version
+ // in the referenced operator sets. This means at most one version can be relied
+ // for one domain.
+ //
+ // The operator sets imported by FunctionProto should be compatible with the ones
+ // imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
+ // and ModelProto then versions for the operator set may be different but,
+ // the operator schema returned for op_type, domain, version combination
+ // for both the versions should be same.
+
+ repeated OperatorSetIdProto opset_import = 9;
+
+ // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
+ // the FunctionProto.
+ optional string domain = 10;
+}
+
+
+// For using protobuf-lite
+option optimize_for = LITE_RUNTIME;
+
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
index 36f09f989a7..b662179c418 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
@@ -24,7 +24,6 @@ import com.yahoo.config.provision.ClusterSpec;
import com.yahoo.config.provision.DataplaneToken;
import com.yahoo.config.provision.DockerImage;
import com.yahoo.config.provision.HostName;
-import com.yahoo.config.provision.TenantName;
import com.yahoo.config.provision.Zone;
import com.yahoo.container.jdisc.secretstore.SecretStore;
import com.yahoo.vespa.config.server.tenant.SecretStoreExternalIdRetriever;
@@ -34,6 +33,7 @@ import com.yahoo.vespa.flags.Flags;
import com.yahoo.vespa.flags.PermanentFlags;
import com.yahoo.vespa.flags.StringFlag;
import com.yahoo.vespa.flags.UnboundFlag;
+
import java.io.File;
import java.net.URI;
import java.security.cert.X509Certificate;
@@ -319,13 +319,7 @@ public class ModelContextImpl implements ModelContext {
return flag.bindTo(source)
.with(FetchVector.Dimension.APPLICATION_ID, appId.serializedForm())
.with(FetchVector.Dimension.VESPA_VERSION, vespaVersion.toFullString())
- .boxedValue();
- }
-
- private static <V> V flagValue(FlagSource source, TenantName tenant, Version vespaVersion, UnboundFlag<? extends V, ?, ?> flag) {
- return flag.bindTo(source)
- .with(FetchVector.Dimension.TENANT_ID, tenant.value())
- .with(FetchVector.Dimension.VESPA_VERSION, vespaVersion.toFullString())
+ .with(FetchVector.Dimension.TENANT_ID, appId.tenant().value())
.boxedValue();
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java
index 0532a81617f..b2762b2a3d4 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java
@@ -98,7 +98,7 @@ public class ApplicationApiHandler extends SessionHandler {
"Unable to parse multipart in deploy from tenant '" + tenantName.value() + "': " + Exceptions.toMessageString(e));
var message = "Deploy request from '" + tenantName.value() + "' contains invalid data: " + e.getMessage();
- log.log(INFO, message + ", parts: " + parts, e);
+ log.log(FINE, message + ", parts: " + parts, e);
throw new BadRequestException("Deploy request from '" + tenantName.value() + "' contains invalid data: " + e.getMessage());
}
} else {
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java b/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java
index e6af65c0bc8..47050168b80 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java
@@ -36,8 +36,8 @@ public class DataplaneProxyService extends AbstractComponent {
private final Path root;
enum NginxState {INITIALIZING, RUNNING, RELOAD_REQUIRED, STOPPED};
- private NginxState state;
- private NginxState wantedState;
+ private volatile NginxState state;
+ private volatile NginxState wantedState;
private DataplaneProxyConfig cfg;
private Path proxyCredentialsCert;
@@ -113,35 +113,46 @@ public class DataplaneProxyService extends AbstractComponent {
throw new RuntimeException("Error reconfiguring data plane proxy", e);
}
}
- if (wantedState == NginxState.RUNNING) {
+ NginxState convergeTo = wantedState;
+ if (convergeTo == NginxState.RUNNING) {
boolean nginxRunning = proxyCommands.isRunning();
if (!nginxRunning) {
try {
proxyCommands.start(nginxConf);
- changeState(wantedState);
+ changeState(convergeTo);
} catch (Exception e) {
logger.log(Level.INFO, "Failed to start nginx, will retry");
+ logger.log(Level.FINE, "Exception from nginx start", e);
}
- } else if (nginxRunning && state == NginxState.RELOAD_REQUIRED) {
- try {
- proxyCommands.reload();
- changeState(wantedState);
- } catch (Exception e) {
- logger.log(Level.INFO, "Failed to reconfigure nginx, will retry.");
+ } else {
+ if (state == NginxState.RELOAD_REQUIRED) {
+ try {
+ proxyCommands.reload();
+ changeState(convergeTo);
+ } catch (Exception e) {
+ logger.log(Level.INFO, "Failed to reconfigure nginx, will retry.");
+ logger.log(Level.FINE, "Exception from nginx reload", e);
+ }
+ } else if (state != convergeTo) {
+ // Already running, but state not updated
+ changeState(convergeTo);
}
}
- } else if (wantedState == NginxState.STOPPED) {
+ } else if (convergeTo == NginxState.STOPPED) {
if (proxyCommands.isRunning()) {
try {
proxyCommands.stop();
- changeState(wantedState);
- executorService.shutdownNow();
} catch (Exception e) {
logger.log(Level.INFO, "Failed to stop nginx, will retry");
+ logger.log(Level.FINE, "Exception from nginx stop", e);
}
}
+ if (! proxyCommands.isRunning()) {
+ changeState(convergeTo);
+ executorService.shutdownNow();
+ }
} else {
- logger.warning("Unknown state " + wantedState);
+ logger.warning("Unknown state " + convergeTo);
}
}
@@ -150,9 +161,9 @@ public class DataplaneProxyService extends AbstractComponent {
super.deconstruct();
wantedState = NginxState.STOPPED;
try {
- executorService.awaitTermination(5, TimeUnit.MINUTES);
+ executorService.awaitTermination(30, TimeUnit.SECONDS);
} catch (InterruptedException e) {
- logger.log(Level.WARNING, "Error shutting down proxy reload thread");
+ logger.log(Level.WARNING, "Error shutting down proxy reload thread", e);
}
}
@@ -203,10 +214,12 @@ public class DataplaneProxyService extends AbstractComponent {
return template.replaceAll("\\$\\{%s\\}".formatted(key), value);
}
+ // Used for testing
NginxState state() {
return state;
}
+ // Used for testing
NginxState wantedState() {
return wantedState;
}
diff --git a/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java b/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java
index 947c99adf51..351890e2a3a 100644
--- a/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java
+++ b/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java
@@ -22,13 +22,16 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class DataplaneProxyServiceTest {
private FileSystem fileSystem = Jimfs.newFileSystem();
- DataplaneProxyService.ProxyCommands proxyCommandsMock = Mockito.mock(DataplaneProxyService.ProxyCommands.class);
+ DataplaneProxyService.ProxyCommands proxyCommandsMock = mock(DataplaneProxyService.ProxyCommands.class);
@Test
public void starts_and_reloads_if_no_errors() throws IOException {
@@ -122,6 +125,35 @@ public class DataplaneProxyServiceTest {
assertFalse(proxyCommands.isRunning());
}
+ @Test
+ public void stops_executor_when_nginx_stop_throws() throws IOException, InterruptedException {
+ DataplaneProxyService.ProxyCommands mockProxyCommands = mock(DataplaneProxyService.ProxyCommands.class);
+ DataplaneProxyService service = dataplaneProxyService(mockProxyCommands);
+ service.converge();
+ when (mockProxyCommands.isRunning()).thenReturn(true);
+ assertEquals(DataplaneProxyService.NginxState.RUNNING, service.state());
+
+ reset(proxyCommandsMock);
+
+ when(mockProxyCommands.isRunning()).thenReturn(true).thenReturn(false);
+ doThrow(new RuntimeException("Failed to stop proxy")).when(proxyCommandsMock).stop();
+ Thread thread = new Thread(service::deconstruct);// deconstruct will block until nginx is stopped
+ thread.start();
+
+ // Wait for above thread to set the wanted state to STOPPED
+ while (service.wantedState() != DataplaneProxyService.NginxState.STOPPED) {
+ try {
+ Thread.sleep(10);
+ } catch (InterruptedException e) {
+ }
+ }
+ service.converge();
+ assertEquals(service.state(), DataplaneProxyService.NginxState.STOPPED);
+ thread.join();
+
+ verify(mockProxyCommands, times(1)).stop();
+ }
+
private DataplaneProxyService dataplaneProxyService(DataplaneProxyService.ProxyCommands proxyCommands) throws IOException {
Path root = fileSystem.getPath("/opt/vespa");
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/application/v4/model/InstanceInformation.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/application/v4/model/InstanceInformation.java
index ffaee34e727..d73a7410cc6 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/application/v4/model/InstanceInformation.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/application/v4/model/InstanceInformation.java
@@ -26,18 +26,21 @@ public class InstanceInformation {
public URI url;
public String scope;
public RoutingMethod routingMethod;
+ public String auth;
@JsonCreator
public Endpoint(@JsonProperty("cluster") String cluster ,
@JsonProperty("tls") boolean tls,
@JsonProperty("url") URI url,
@JsonProperty("scope") String scope,
- @JsonProperty("routingMethod") RoutingMethod routingMethod) {
+ @JsonProperty("routingMethod") RoutingMethod routingMethod,
+ @JsonProperty("authMethod") String auth) {
this.cluster = cluster;
this.tls = tls;
this.url = url;
this.scope = scope;
this.routingMethod = routingMethod;
+ this.auth = auth;
}
@Override
@@ -47,6 +50,7 @@ public class InstanceInformation {
", tls=" + tls +
", url=" + url +
", scope='" + scope + '\'' +
+ ", authType='" + auth + '\'' +
", routingMethod=" + routingMethod +
'}';
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java
index 0f3f9479176..68852f90055 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java
@@ -129,7 +129,7 @@ public class EndpointCertificates {
}
private Optional<EndpointCertificateMetadata> getOrProvision(Instance instance, ZoneId zone, DeploymentSpec deploymentSpec) {
- if (useRandomizedCert.with(FetchVector.Dimension.APPLICATION_ID, instance.id().toFullString()).value()) {
+ if (useRandomizedCert.with(FetchVector.Dimension.APPLICATION_ID, instance.id().serializedForm()).value()) {
return Optional.of(assignFromPool(instance, zone));
}
Optional<AssignedCertificate> assignedCertificate = curator.readAssignedCertificate(TenantAndApplicationId.from(instance.id()), Optional.of(instance.id().instance()));
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
index aa3f78f1395..693275987c5 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
@@ -139,7 +139,6 @@ import com.yahoo.yolean.Exceptions;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
-import java.io.UncheckedIOException;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
@@ -911,14 +910,17 @@ public class ApplicationApiHandler extends AuditLoggingRequestHandler {
}
private HttpResponse listTokens(String tenant, HttpRequest request) {
- List<DataplaneTokenVersions> dataplaneTokenVersions = controller.dataplaneTokenService().listTokens(TenantName.from(tenant));
+ var tokens = controller.dataplaneTokenService().listTokens(TenantName.from(tenant))
+ .stream().sorted(Comparator.comparing(DataplaneTokenVersions::tokenId)).toList();
Slime slime = new Slime();
Cursor tokensArray = slime.setObject().setArray("tokens");
- for (DataplaneTokenVersions token : dataplaneTokenVersions) {
+ for (DataplaneTokenVersions token : tokens) {
Cursor tokenObject = tokensArray.addObject();
tokenObject.setString("id", token.tokenId().value());
Cursor fingerprintsArray = tokenObject.setArray("versions");
- for (DataplaneTokenVersions.Version tokenVersion : token.tokenVersions()) {
+ var versions = token.tokenVersions().stream()
+ .sorted(Comparator.comparing(DataplaneTokenVersions.Version::creationTime)).toList();
+ for (var tokenVersion : versions) {
Cursor fingerprintObject = fingerprintsArray.addObject();
fingerprintObject.setString("fingerprint", tokenVersion.fingerPrint().value());
fingerprintObject.setString("created", tokenVersion.creationTime().toString());
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
index f12f60dcc8e..f690b8e8c8a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
@@ -32,8 +32,9 @@ class TensorConverter {
}
private static Values readValuesOf(Onnx.TensorProto tensorProto) {
+ var elemType = Onnx.TensorProto.DataType.forNumber(tensorProto.getDataType());
if (tensorProto.hasRawData()) {
- switch (tensorProto.getDataType()) {
+ switch (elemType) {
case BOOL: return new RawBoolValues(tensorProto);
case FLOAT: return new RawFloatValues(tensorProto);
case DOUBLE: return new RawDoubleValues(tensorProto);
@@ -41,7 +42,7 @@ class TensorConverter {
case INT64: return new RawLongValues(tensorProto);
}
} else {
- switch (tensorProto.getDataType()) {
+ switch (elemType) {
case FLOAT: return new FloatValues(tensorProto);
case DOUBLE: return new DoubleValues(tensorProto);
case INT32: return new IntValues(tensorProto);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index 35ec1d8c54a..deac950d324 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -37,7 +37,8 @@ class TypeConverter {
static OrderedTensorType typeFrom(Onnx.TypeProto type) {
String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(type.getTensorType().getElemType()));
+ var elemType = Onnx.TensorProto.DataType.forNumber(type.getTensorType().getElemType());
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(elemType));
for (int i = 0; i < shape.getDimCount(); ++ i) {
String dimensionName = dimensionPrefix + i;
Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
@@ -52,8 +53,8 @@ class TypeConverter {
}
static OrderedTensorType typeFrom(Onnx.TensorProto tensor) {
- return OrderedTensorType.fromDimensionList(toValueType(tensor.getDataType()),
- tensor.getDimsList());
+ var elemType = Onnx.TensorProto.DataType.forNumber(tensor.getDataType());
+ return OrderedTensorType.fromDimensionList(toValueType(elemType), tensor.getDimsList());
}
private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
diff --git a/model-integration/src/main/protobuf/onnx.proto b/model-integration/src/main/protobuf/onnx.proto
index dc6542867e0..1d265ae9f28 100644
--- a/model-integration/src/main/protobuf/onnx.proto
+++ b/model-integration/src/main/protobuf/onnx.proto
@@ -3,8 +3,8 @@
//
-// Copyright (c) Facebook Inc. and Microsoft Corporation.
-// Licensed under the MIT license.
+// SPDX-License-Identifier: Apache-2.0
+
syntax = "proto2";
@@ -20,23 +20,16 @@ package onnx;
//
// This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX
-// Intermediate Representation, or 'IR' for short.
+// Intermediate Representation, or 'IR' for short.
//
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md.
// Notes
//
-// Release
-//
-// We are still in the very early stage of defining ONNX. The current
-// version of ONNX is a starting point. While we are actively working
-// towards a complete spec, we would like to get the community involved
-// by sharing our working version of ONNX.
-//
// Protobuf compatibility
-//
-// To simplify framework compatibility, ONNX is defined using the subset of protobuf
+//
+// To simplify framework compatibility, ONNX is defined using the subset of protobuf
// that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in one of the two versions.
//
@@ -60,22 +53,60 @@ enum Version {
_START_VERSION = 0;
// The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version
- // control. We should use version as
- // xx(major) - xx(minor) - xxxx(bugfix)
- // and we are starting with 0x00000001 (0.0.1), which was the
- // version we published on Oct 10, 2017.
- IR_VERSION_2017_10_10 = 0x00000001;
+ // control.
+ // For the IR, we are using simple numbers starting with 0x00000001,
+ // which was the version we published on Oct 10, 2017.
+ IR_VERSION_2017_10_10 = 0x0000000000000001;
- // IR_VERSION 0.0.2 published on Oct 30, 2017
+ // IR_VERSION 2 published on Oct 30, 2017
// - Added type discriminator to AttributeProto to support proto3 users
- IR_VERSION_2017_10_30 = 0x00000002;
+ IR_VERSION_2017_10_30 = 0x0000000000000002;
- // IR VERSION 0.0.3 published on Nov 3, 2017
+ // IR VERSION 3 published on Nov 3, 2017
// - For operator versioning:
// - Added new message OperatorSetIdProto
// - Added opset_import in ModelProto
// - For vendor extensions, added domain in NodeProto
- IR_VERSION = 0x00000003;
+ IR_VERSION_2017_11_3 = 0x0000000000000003;
+
+ // IR VERSION 4 published on Jan 22, 2019
+ // - Relax constraint that initializers should be a subset of graph inputs
+ // - Add type BFLOAT16
+ IR_VERSION_2019_1_22 = 0x0000000000000004;
+
+ // IR VERSION 5 published on March 18, 2019
+ // - Add message TensorAnnotation.
+ // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
+ IR_VERSION_2019_3_18 = 0x0000000000000005;
+
+ // IR VERSION 6 published on Sep 19, 2019
+ // - Add support for sparse tensor constants stored in model.
+ // - Add message SparseTensorProto
+ // - Add sparse initializers
+ IR_VERSION_2019_9_19 = 0x0000000000000006;
+
+ // IR VERSION 7 published on May 8, 2020
+ // - Add support to allow function body graph to rely on multiple external operator sets.
+ // - Add a list to promote inference graph's initializers to global and
+ // mutable variables. Global variables are visible in all graphs of the
+ // stored models.
+ // - Add message TrainingInfoProto to store initialization
+ // method and training algorithm. The execution of TrainingInfoProto
+ // can modify the values of mutable variables.
+ // - Implicitly add inference graph into each TrainingInfoProto's algorithm.
+ IR_VERSION_2020_5_8 = 0x0000000000000007;
+
+ // IR VERSION 8 published on July 30, 2021
+ // Introduce TypeProto.SparseTensor
+ // Introduce TypeProto.Optional
+ // Added a list of FunctionProtos local to the model
+ // Deprecated since_version and operator status from FunctionProto
+ IR_VERSION_2021_7_30 = 0x0000000000000008;
+
+ // IR VERSION 9 published on May 5, 2023
+ // Added AttributeProto to FunctionProto so that default attribute values can be set.
+ // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
+ IR_VERSION = 0x0000000000000009;
}
// Attributes
@@ -95,17 +126,21 @@ message AttributeProto {
STRING = 3;
TENSOR = 4;
GRAPH = 5;
+ SPARSE_TENSOR = 11;
+ TYPE_PROTO = 13;
FLOATS = 6;
INTS = 7;
STRINGS = 8;
TENSORS = 9;
GRAPHS = 10;
+ SPARSE_TENSORS = 12;
+ TYPE_PROTOS = 14;
}
// The name field MUST be present for this version of the IR.
optional string name = 1; // namespace Attribute
-
+
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope.
@@ -117,10 +152,10 @@ message AttributeProto {
// The type field MUST be present for this version of the IR.
// For 0.0.1 versions of the IR, this field was not defined, and
- // implementations needed to use has_field hueristics to determine
+ // implementations needed to use has_field heuristics to determine
// which value field was in use. For IR_VERSION 0.0.2 or later, this
// field MUST be set and match the f|i|s|t|... field in use. This
- // change was made to accomodate proto3 implementations.
+ // change was made to accommodate proto3 implementations.
optional AttributeType type = 20; // discriminator that indicates which field below is in use
// Exactly ONE of the following fields must be present for this version of the IR
@@ -129,14 +164,18 @@ message AttributeProto {
optional bytes s = 4; // UTF-8 string
optional TensorProto t = 5; // tensor value
optional GraphProto g = 6; // graph
+ optional SparseTensorProto sparse_tensor = 22; // sparse tensor value
// Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph
+ optional TypeProto tp = 14; // type proto
repeated float floats = 7; // list of floats
repeated int64 ints = 8; // list of ints
repeated bytes strings = 9; // list of UTF-8 strings
repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph
+ repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
+ repeated TypeProto type_protos = 15;// list of type protos
}
// Defines information on value, including the name, the type, and
@@ -144,7 +183,8 @@ message AttributeProto {
message ValueInfoProto {
// This field MUST be present in this version of the IR.
optional string name = 1; // namespace Value
- // This field MUST be present in this version of the IR.
+ // This field MUST be present in this version of the IR for
+ // inputs and outputs of the top-level graph.
optional TypeProto type = 2;
// A human-readable documentation for this value. Markdown is allowed.
optional string doc_string = 3;
@@ -155,7 +195,7 @@ message ValueInfoProto {
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
-// For example, it can be a node of type "Conv" that takes in an image, a filter
+// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
repeated string input = 1; // namespace Value
@@ -177,12 +217,130 @@ message NodeProto {
optional string doc_string = 6;
}
+// Training information
+// TrainingInfoProto stores information for training a model.
+// In particular, this defines two functionalities: an initialization-step
+// and a training-algorithm-step. Initialization resets the model
+// back to its original state as if no training has been performed.
+// Training algorithm improves the model based on input data.
+//
+// The semantics of the initialization-step is that the initializers
+// in ModelProto.graph and in TrainingInfoProto.algorithm are first
+// initialized as specified by the initializers in the graph, and then
+// updated by the "initialization_binding" in every instance in
+// ModelProto.training_info.
+//
+// The field "algorithm" defines a computation graph which represents a
+// training algorithm's step. After the execution of a
+// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
+// may be immediately updated. If the targeted training algorithm contains
+// consecutive update steps (such as block coordinate descent methods),
+// the user needs to create a TrainingInfoProto for each step.
+message TrainingInfoProto {
+ // This field describes a graph to compute the initial tensors
+ // upon starting the training process. Initialization graph has no input
+ // and can have multiple outputs. Usually, trainable tensors in neural
+ // networks are randomly initialized. To achieve that, for each tensor,
+ // the user can put a random number operator such as RandomNormal or
+ // RandomUniform in TrainingInfoProto.initialization.node and assign its
+ // random output to the specific tensor using "initialization_binding".
+ // This graph can also set the initializers in "algorithm" in the same
+ // TrainingInfoProto; a use case is resetting the number of training
+  // iterations to zero.
+ //
+ // By default, this field is an empty graph and its evaluation does not
+ // produce any output. Thus, no initializer would be changed by default.
+ optional GraphProto initialization = 1;
+
+ // This field represents a training algorithm step. Given required inputs,
+ // it computes outputs to update initializers in its own or inference graph's
+ // initializer lists. In general, this field contains loss node, gradient node,
+ // optimizer node, increment of iteration count.
+ //
+ // An execution of the training algorithm step is performed by executing the
+ // graph obtained by combining the inference graph (namely "ModelProto.graph")
+ // and the "algorithm" graph. That is, the actual
+ // input/initializer/output/node/value_info/sparse_initializer list of
+ // the training graph is the concatenation of
+ // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
+ // and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
+ // in that order. This combined graph must satisfy the normal ONNX conditions.
+ // Now, let's provide a visualization of graph combination for clarity.
+ // Let the inference graph (i.e., "ModelProto.graph") be
+ // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
+ // and the "algorithm" graph be
+ // tensor_d -> Add -> tensor_e
+  //  The combination process results in
+ // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
+ //
+ // Notice that an input of a node in the "algorithm" graph may reference the
+ // output of a node in the inference graph (but not the other way round). Also, inference
+ // node cannot reference inputs of "algorithm". With these restrictions, inference graph
+ // can always be run independently without training information.
+ //
+ // By default, this field is an empty graph and its evaluation does not
+ // produce any output. Evaluating the default training step never
+  // updates any initializers.
+ optional GraphProto algorithm = 2;
+
+ // This field specifies the bindings from the outputs of "initialization" to
+ // some initializers in "ModelProto.graph.initializer" and
+ // the "algorithm.initializer" in the same TrainingInfoProto.
+ // See "update_binding" below for details.
+ //
+ // By default, this field is empty and no initializer would be changed
+ // by the execution of "initialization".
+ repeated StringStringEntryProto initialization_binding = 3;
+
+ // Gradient-based training is usually an iterative procedure. In one gradient
+ // descent iteration, we apply
+ //
+ // x = x - r * g
+ //
+ // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
+ // gradient of "x" with respect to a chosen loss. To avoid adding assignments
+ // into the training graph, we split the update equation into
+ //
+ // y = x - r * g
+ // x = y
+ //
+ // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
+ // tell that "y" should be assigned to "x", the field "update_binding" may
+ // contain a key-value pair of strings, "x" (key of StringStringEntryProto)
+ // and "y" (value of StringStringEntryProto).
+ // For a neural network with multiple trainable (mutable) tensors, there can
+ // be multiple key-value pairs in "update_binding".
+ //
+  // The initializers appearing as keys in "update_binding" are considered
+ // mutable variables. This implies some behaviors
+ // as described below.
+ //
+ // 1. We have only unique keys in all "update_binding"s so that two
+ // variables may not have the same name. This ensures that one
+ // variable is assigned up to once.
+ // 2. The keys must appear in names of "ModelProto.graph.initializer" or
+ // "TrainingInfoProto.algorithm.initializer".
+ // 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
+ // 4. Mutable variables are initialized to the value specified by the
+ // corresponding initializer, and then potentially updated by
+  //    "initialization_binding"s and "update_binding"s in "TrainingInfoProto"s.
+ //
+ // This field usually contains names of trainable tensors
+ // (in ModelProto.graph), optimizer states such as momentums in advanced
+ // stochastic gradient methods (in TrainingInfoProto.graph),
+ // and number of training iterations (in TrainingInfoProto.graph).
+ //
+ // By default, this field is empty and no initializer would be changed
+ // by the execution of "algorithm".
+ repeated StringStringEntryProto update_binding = 4;
+}
+
// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
-// The semantics of the model are described by the associated GraphProto.
+// The semantics of the model are described by the associated GraphProto's.
message ModelProto {
// The version of the IR this model targets. See Version enum above.
// This field MUST be present.
@@ -227,18 +385,58 @@ message ModelProto {
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14;
+
+ // Training-specific information. Sequentially executing all stored
+ // `TrainingInfoProto.algorithm`s and assigning their outputs following
+ // the corresponding `TrainingInfoProto.update_binding`s is one training
+ // iteration. Similarly, to initialize the model
+ // (as if training hasn't happened), the user should sequentially execute
+ // all stored `TrainingInfoProto.initialization`s and assigns their outputs
+ // using `TrainingInfoProto.initialization_binding`s.
+ //
+ // If this field is empty, the training behavior of the model is undefined.
+ repeated TrainingInfoProto training_info = 20;
+
+ // A list of function protos local to the model.
+ //
+ // Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
+ // In case of any conflicts the behavior (whether the model local functions are given higher priority,
+ // or standard operator sets are given higher priority or this is treated as error) is defined by
+ // the runtimes.
+ //
+ // The operator sets imported by FunctionProto should be compatible with the ones
+ // imported by ModelProto and other model local FunctionProtos.
+ // Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
+ // or by 2 FunctionProtos then versions for the operator set may be different but,
+ // the operator schema returned for op_type, domain, version combination
+ // for both the versions should be same for every node in the function body.
+ //
+ // One FunctionProto can reference other FunctionProto in the model, however, recursive reference
+ // is not allowed.
+ repeated FunctionProto functions = 25;
};
// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
optional string key = 1;
- optional string value= 2;
+ optional string value = 2;
};
+message TensorAnnotation {
+ optional string tensor_name = 1;
+ // <key, value> pairs to annotate tensor specified by <tensor_name> above.
+ // The keys used in the mapping below must be pre-defined in ONNX spec.
+ // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
+ // quantization parameter keys.
+ repeated StringStringEntryProto quant_parameter_tensor_names = 2;
+}
+
+
+
// Graphs
//
-// A graph defines the computational logic of a model and is comprised of a parameterized
+// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
@@ -250,10 +448,14 @@ message GraphProto {
optional string name = 2; // namespace Graph
// A list of named tensor values, used to specify constant inputs of the graph.
- // Each TensorProto entry must have a distinct name (within the list) that
- // also appears in the input list.
+ // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
+ // The name MUST be unique across both initializer and sparse_initializer,
+ // but the name MAY also appear in the input list.
repeated TensorProto initializer = 5;
+ // Initializers (see above) stored in sparse format.
+ repeated SparseTensorProto sparse_initializer = 15;
+
// A human-readable documentation for this graph. Markdown is allowed.
optional string doc_string = 10;
@@ -265,13 +467,14 @@ message GraphProto {
// must be distinct. It is optional for a value to appear in value_info list.
repeated ValueInfoProto value_info = 13;
- // DO NOT USE the following fields, they were deprecated from earlier versions.
- // repeated string input = 3;
- // repeated string output = 4;
- // optional int64 ir_version = 6;
- // optional int64 producer_version = 7;
- // optional string producer_tag = 8;
- // optional string domain = 9;
+ // This field carries information to indicate the mapping among a tensor and its
+ // quantization parameter tensors. For example:
+ // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
+ // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
+ repeated TensorAnnotation quantization_annotation = 14;
+
+ reserved 3, 4, 6 to 9;
+ reserved "ir_version", "producer_version", "producer_tag", "domain";
}
// Tensors
@@ -291,13 +494,32 @@ message TensorProto {
STRING = 8; // string
BOOL = 9; // bool
- // Advanced types
+ // IEEE754 half-precision floating-point format (16 bits wide).
+ // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10;
+
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
+
+ // Non-IEEE floating-point format based on IEEE754 single-precision
+ // floating-point number truncated to 16 bits.
+ // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
+ BFLOAT16 = 16;
+
+ // Non-IEEE floating-point format based on papers
+ // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
+ // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
+ // Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
+ // The computation usually happens inside a block quantize / dequantize
+ // fused by the runtime.
+ FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
+ FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
+ FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
+ FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
+
// Future extensions go here.
}
@@ -305,7 +527,8 @@ message TensorProto {
repeated int64 dims = 1;
// The data type of the tensor.
- optional DataType data_type = 2;
+ // This field MUST have a valid TensorProto.DataType value
+ optional int32 data_type = 2;
// For very large tensors, we may want to store them in chunks, in which
// case the following fields will specify the segment that is stored in
@@ -324,17 +547,17 @@ message TensorProto {
// For float and complex64 values
// Complex64 tensors are encoded as a single array of floats,
// with the real components appearing in odd numbered positions,
- // and the corresponding imaginary component apparing in the
+ // and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true];
- // For int32, uint8, int8, uint16, int16, bool, and float16 values
- // float16 values must be bit-wise converted to an uint16_t prior
+ // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
+ // float16 and float8 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer.
// When this field is present, the data_type field MUST be
- // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32
+ // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
repeated int32 int32_data = 5 [packed = true];
// For strings.
@@ -371,10 +594,32 @@ message TensorProto {
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
optional bytes raw_data = 9;
+ // Data can be stored inside the protobuf file using type-specific fields or raw_data.
+ // Alternatively, raw bytes data can be stored in an external file, using the external_data field.
+ // external_data stores key-value pairs describing data location. Recognized keys are:
+ // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
+ // protobuf model was stored
+ // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
+ //                    Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.
+ // - "length" (optional) - number of bytes containing data. Integer stored as string.
+ // - "checksum" (optional) - SHA1 digest of file specified under 'location' key.
+ repeated StringStringEntryProto external_data = 13;
+
+ // Location of the data for this tensor. MUST be one of:
+ // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
+ // - EXTERNAL - data stored in an external location as described by external_data field.
+ enum DataLocation {
+ DEFAULT = 0;
+ EXTERNAL = 1;
+ }
+
+ // If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
+ optional DataLocation data_location = 14;
+
// For double
- // Complex64 tensors are encoded as a single array of doubles,
+ // Complex128 tensors are encoded as a single array of doubles,
// with the real components appearing in odd numbered positions,
- // and the corresponding imaginary component apparing in the
+ // and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
@@ -386,6 +631,30 @@ message TensorProto {
repeated uint64 uint64_data = 11 [packed = true];
}
+// A serialized sparse-tensor value
+message SparseTensorProto {
+ // The sequence of non-default values are encoded as a tensor of shape [NNZ].
+ // The default-value is zero for numeric tensors, and empty-string for string tensors.
+ // values must have a non-empty name present which serves as a name for SparseTensorProto
+ // when used in sparse_initializer list.
+ optional TensorProto values = 1;
+
+ // The indices of the non-default values, which may be stored in one of two formats.
+ // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
+ // corresponding to the j-th index of the i-th value (in the values tensor).
+ // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
+ // must be the linearized-index of the i-th value (in the values tensor).
+ // The linearized-index can be converted into an index tuple (k_1,...,k_rank)
+ // using the shape provided below.
+ // The indices must appear in ascending order without duplication.
+ // In the first format, the ordering is lexicographic-ordering:
+ // e.g., index-value [1,4] must appear before [2,1]
+ optional TensorProto indices = 2;
+
+ // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
+ repeated int64 dims = 3;
+}
+
// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
@@ -398,36 +667,13 @@ message TensorShapeProto {
// Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor.
+ // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
+ // for pre-defined dimension denotations.
optional string denotation = 3;
};
repeated Dimension dim = 1;
}
-// A set of pre-defined constants to be used as values for
-// the standard denotation field in TensorShapeProto.Dimension
-// for semantic description of the tensor dimension.
-message DenotationConstProto {
- // Describe a batch number dimension.
- optional string DATA_BATCH = 1 [default = "DATA_BATCH"];
- // Describe a channel dimension.
- optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"];
- // Describe a time dimension.
- optional string DATA_TIME = 3 [default = "DATA_TIME"];
- // Describe a feature dimension. This is typically a feature
- // dimension in RNN and/or spatial dimension in CNN.
- optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"];
- // Describe a filter in-channel dimension. This is the dimension
- // that is identical (in size) to the channel dimension of the input
- // image feature maps.
- optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"];
- // Describe a filter out channel dimension. This is the dimension
- // that is identical (int size) to the channel dimension of the output
- // image feature maps.
- optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"];
- // Describe a filter spatial dimension.
- optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"];
-}
-
// Types
//
// The standard ONNX data types.
@@ -435,8 +681,43 @@ message TypeProto {
message Tensor {
// This field MUST NOT have the value of UNDEFINED
+ // This field MUST have a valid TensorProto.DataType value
+ // This field MUST be present for this version of the IR.
+ optional int32 elem_type = 1;
+ optional TensorShapeProto shape = 2;
+ }
+
+ // repeated T
+ message Sequence {
+ // The type and optional shape of each element of the sequence.
+ // This field MUST be present for this version of the IR.
+ optional TypeProto elem_type = 1;
+ };
+
+ // map<K,V>
+ message Map {
+ // This field MUST have a valid TensorProto.DataType value
+ // This field MUST be present for this version of the IR.
+ // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
+ optional int32 key_type = 1;
+ // This field MUST be present for this version of the IR.
+ optional TypeProto value_type = 2;
+ };
+
+ // wrapper for Tensor, Sequence, or Map
+ message Optional {
+ // The type and optional shape of the element wrapped.
+ // This field MUST be present for this version of the IR.
+ // Possible values correspond to OptionalProto.DataType enum
+ optional TypeProto elem_type = 1;
+ };
+
+
+ message SparseTensor {
+ // This field MUST NOT have the value of UNDEFINED
+ // This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
- optional TensorProto.DataType elem_type = 1;
+ optional int32 elem_type = 1;
optional TensorShapeProto shape = 2;
}
@@ -445,7 +726,31 @@ message TypeProto {
// The type of a tensor.
Tensor tensor_type = 1;
+ // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
+ // as input and output to graphs and nodes. These types are needed to naturally
+ // support classical ML operators. DNN operators SHOULD restrict their input
+ // and output types to tensors.
+
+ // The type of a sequence.
+ Sequence sequence_type = 4;
+
+ // The type of a map.
+ Map map_type = 5;
+
+ // The type of an optional.
+ Optional optional_type = 9;
+
+
+ // Type of the sparse tensor
+ SparseTensor sparse_tensor_type = 8;
+
}
+
+ // An optional denotation can be used to denote the whole
+ // type with a standard semantic description as to what is
+ // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
+ // for pre-defined type denotations.
+ optional string denotation = 6;
}
// Operator Sets
@@ -461,4 +766,70 @@ message OperatorSetIdProto {
// The version of the operator set being identified.
// This field MUST be present in this version of the IR.
optional int64 version = 2;
-} \ No newline at end of file
+}
+
+// Operator/function status.
+enum OperatorStatus {
+ EXPERIMENTAL = 0;
+ STABLE = 1;
+}
+
+message FunctionProto {
+ // The name of the function, similar usage of op_type in OperatorProto.
+ // Combined with FunctionProto.domain, this forms the unique identity of
+ // the FunctionProto.
+ optional string name = 1;
+
+ // Deprecated since IR Version 8
+ // optional int64 since_version = 2;
+ reserved 2;
+ reserved "since_version";
+
+ // Deprecated since IR Version 8
+ // optional OperatorStatus status = 3;
+ reserved 3;
+ reserved "status";
+
+ // The inputs and outputs of the function.
+ repeated string input = 4;
+ repeated string output = 5;
+
+ // The attribute parameters of the function.
+ // It is for function parameters without default values.
+ repeated string attribute = 6;
+
+ // The attribute protos of the function.
+ // It is for function attributes with default values.
+ // A function attribute shall be represented either as
+ // a string attribute or an AttributeProto, not both.
+ repeated AttributeProto attribute_proto = 11;
+
+ // The nodes in the function.
+ repeated NodeProto node = 7;
+ // A human-readable documentation for this function. Markdown is allowed.
+ optional string doc_string = 8;
+
+ // The OperatorSets this function body (graph) relies on.
+ //
+ // All nodes in the function body (graph) will bind against the operator
+ // with the same-domain/same-op_type operator with the HIGHEST version
+ // in the referenced operator sets. This means at most one version can be relied
+ // for one domain.
+ //
+ // The operator sets imported by FunctionProto should be compatible with the ones
+ // imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
+ // and ModelProto then versions for the operator set may be different but,
+ // the operator schema returned for op_type, domain, version combination
+ // for both the versions should be same.
+
+ repeated OperatorSetIdProto opset_import = 9;
+
+ // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
+ // the FunctionProto.
+ optional string domain = 10;
+}
+
+
+// For using protobuf-lite
+option optimize_for = LITE_RUNTIME;
+
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
index 3ef96cdf166..2b707c3beb3 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -775,10 +775,10 @@ public class OnnxOperationsTestCase {
Onnx.TensorProto.Builder builder = Onnx.TensorProto.newBuilder();
tensor.type().dimensions().forEach(d -> builder.addDims(d.size().get()));
if (tensor.type().valueType() == TensorType.Value.FLOAT) {
- builder.setDataType(Onnx.TensorProto.DataType.FLOAT);
+ builder.setDataType(Onnx.TensorProto.DataType.FLOAT_VALUE);
tensor.valueIterator().forEachRemaining(d -> builder.addFloatData(d.floatValue()));
} else {
- builder.setDataType(Onnx.TensorProto.DataType.DOUBLE);
+ builder.setDataType(Onnx.TensorProto.DataType.DOUBLE_VALUE);
tensor.valueIterator().forEachRemaining(builder::addDoubleData);
}
Onnx.TensorProto val = builder.build();
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/LockedNodeList.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/LockedNodeList.java
index dc86daf2c67..9bc18533ddf 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/LockedNodeList.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/LockedNodeList.java
@@ -17,9 +17,16 @@ import java.util.Objects;
*/
public final class LockedNodeList extends NodeList {
+ private final Mutex lock;
+
public LockedNodeList(List<Node> nodes, Mutex lock) {
super(nodes, false);
- Objects.requireNonNull(lock, "lock must be non-null");
+ this.lock = Objects.requireNonNull(lock, "lock must be non-null");
+ }
+
+ /** Returns a new LockedNodeList of the given nodes, for the same lock. */
+ public LockedNodeList childList(List<Node> nodes) {
+ return new LockedNodeList(nodes, lock);
}
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeMutex.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeMutex.java
index 60fd07951c6..20c246b3ebd 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeMutex.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeMutex.java
@@ -28,4 +28,5 @@ public class NodeMutex implements Mutex {
return new NodeMutex(updatedNode, mutex);
}
+
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java
index d6671d41cbd..9da66413b9c 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java
@@ -229,7 +229,7 @@ public class NodeRepository extends AbstractComponent {
applicationNodes.asList(),
Agent.system,
Optional.of("Application is removed"),
- transaction.nested());
+ transaction);
applications.remove(transaction);
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java
index 8766dea3d61..e300591fbb2 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java
@@ -3,6 +3,8 @@ package com.yahoo.vespa.hosted.provision.maintenance;
import com.yahoo.jdisc.Metric;
import com.yahoo.vespa.hosted.provision.Node;
+import com.yahoo.vespa.hosted.provision.Node.State;
+import com.yahoo.vespa.hosted.provision.NodeList;
import com.yahoo.vespa.hosted.provision.NodeRepository;
import com.yahoo.vespa.hosted.provision.node.Agent;
import com.yahoo.vespa.hosted.provision.node.History;
@@ -33,8 +35,12 @@ public class DirtyExpirer extends Expirer {
@Override
protected void expire(List<Node> expired) {
- for (Node expiredNode : expired)
- nodeRepository().nodes().fail(expiredNode.hostname(), wantToDeprovisionOnExpiry, Agent.DirtyExpirer, "Node is stuck in dirty");
+ nodeRepository().nodes().performOn(NodeList.copyOf(expired),
+ node -> node.state() == State.dirty && isExpired(node),
+ (node, lock) -> nodeRepository().nodes().fail(node.hostname(),
+ wantToDeprovisionOnExpiry,
+ Agent.DirtyExpirer,
+ "Node is stuck in dirty"));
}
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/FailedExpirer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/FailedExpirer.java
index fa3f9435c70..cb0a8005e87 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/FailedExpirer.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/FailedExpirer.java
@@ -6,13 +6,14 @@ import com.yahoo.config.provision.NodeType;
import com.yahoo.config.provision.Zone;
import com.yahoo.jdisc.Metric;
import com.yahoo.vespa.hosted.provision.Node;
+import com.yahoo.vespa.hosted.provision.Node.State;
import com.yahoo.vespa.hosted.provision.NodeList;
+import com.yahoo.vespa.hosted.provision.NodeMutex;
import com.yahoo.vespa.hosted.provision.NodeRepository;
import com.yahoo.vespa.hosted.provision.node.Agent;
-import com.yahoo.vespa.hosted.provision.node.History;
+import com.yahoo.vespa.hosted.provision.node.History.Event.Type;
import java.time.Duration;
-import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;
@@ -67,55 +68,47 @@ public class FailedExpirer extends NodeRepositoryMaintainer {
@Override
protected double maintain() {
- NodeList allNodes = nodeRepository.nodes().list();
- List<Node> remainingNodes = new ArrayList<>(allNodes.state(Node.State.failed)
- .nodeType(NodeType.tenant, NodeType.host)
- .asList());
+ Predicate<Node> isExpired = node -> node.state() == State.failed
+ && node.history().hasEventBefore(Type.failed, clock().instant().minus(expiryFor(node)));
+ NodeList allNodes = nodeRepository.nodes().list(); // Stale snapshot, not critical.
- recycleIf(node -> node.allocation().isEmpty(), remainingNodes, allNodes);
- recycleIf(node -> !node.allocation().get().membership().cluster().isStateful() &&
- node.history().hasEventBefore(History.Event.Type.failed, clock().instant().minus(statelessExpiry)),
- remainingNodes,
- allNodes);
- recycleIf(node -> node.allocation().get().membership().cluster().isStateful() &&
- node.history().hasEventBefore(History.Event.Type.failed, clock().instant().minus(statefulExpiry)),
- remainingNodes,
- allNodes);
+ nodeRepository.nodes().performOn(allNodes.nodeType(NodeType.tenant),
+ isExpired,
+ (node, lock) -> recycle(node, List.of(), allNodes).get());
+
+ nodeRepository.nodes().performOnRecursively(allNodes.nodeType(NodeType.host),
+ nodes -> isExpired.test(nodes.parent().node()),
+ nodes -> recycle(nodes.parent().node(),
+ nodes.children().stream().map(NodeMutex::node).toList(),
+ allNodes)
+ .map(List::of).orElse(List.of()));
return 1.0;
}
- /** Recycle the nodes matching condition, and remove those nodes from the nodes list. */
- private void recycleIf(Predicate<Node> condition, List<Node> failedNodes, NodeList allNodes) {
- List<Node> nodesToRecycle = failedNodes.stream().filter(condition).toList();
- failedNodes.removeAll(nodesToRecycle);
- recycle(nodesToRecycle, allNodes);
+ private Duration expiryFor(Node node) {
+ return node.allocation().isEmpty() ? Duration.ZERO
+ : node.allocation().get().membership().cluster().isStateful() ? statefulExpiry
+ : statelessExpiry;
}
- /** Move eligible nodes to dirty or parked. This may be a subset of the given nodes */
- private void recycle(List<Node> nodes, NodeList allNodes) {
- List<Node> nodesToRecycle = new ArrayList<>();
- for (Node candidate : nodes) {
- Optional<String> reason = shouldPark(candidate, allNodes);
- if (reason.isPresent()) {
- List<String> unparkedChildren = candidate.type().isHost() ?
- allNodes.childrenOf(candidate)
- .not()
- .state(Node.State.parked)
- .mapToList(Node::hostname) :
- List.of();
-
- if (unparkedChildren.isEmpty()) {
- nodeRepository.nodes().park(candidate.hostname(), true, Agent.FailedExpirer,
- "Parked by FailedExpirer due to " + reason.get());
- } else {
- log.info(String.format("Expired failed node %s was not parked because of unparked children: %s",
- candidate.hostname(), String.join(", ", unparkedChildren)));
- }
+ private Optional<Node> recycle(Node node, List<Node> children, NodeList allNodes) {
+ Optional<String> reason = shouldPark(node, allNodes);
+ if (reason.isPresent()) {
+ List<String> unparkedChildren = children.stream()
+ .filter(child -> child.state() != Node.State.parked)
+ .map(Node::hostname)
+ .toList();
+ if (unparkedChildren.isEmpty()) {
+ return Optional.of(nodeRepository.nodes().park(node.hostname(), true, Agent.FailedExpirer,
+ "Parked by FailedExpirer due to " + reason.get()));
} else {
- nodesToRecycle.add(candidate);
+ log.info(String.format("Expired failed node %s was not parked because of unparked children: %s",
+ node.hostname(), String.join(", ", unparkedChildren)));
+ return Optional.empty();
}
+ } else {
+ return Optional.of(nodeRepository.nodes().deallocate(node, Agent.FailedExpirer, "Expired by FailedExpirer"));
}
- nodeRepository.nodes().deallocate(nodesToRecycle, Agent.FailedExpirer, "Expired by FailedExpirer");
}
/** Returns whether the node should be parked instead of recycled */
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/InactiveExpirer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/InactiveExpirer.java
index aa7aac34389..503ac4be86c 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/InactiveExpirer.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/InactiveExpirer.java
@@ -3,6 +3,8 @@ package com.yahoo.vespa.hosted.provision.maintenance;
import com.yahoo.jdisc.Metric;
import com.yahoo.vespa.hosted.provision.Node;
+import com.yahoo.vespa.hosted.provision.Node.State;
+import com.yahoo.vespa.hosted.provision.NodeList;
import com.yahoo.vespa.hosted.provision.NodeRepository;
import com.yahoo.vespa.hosted.provision.node.Agent;
import com.yahoo.vespa.hosted.provision.node.History;
@@ -40,9 +42,9 @@ public class InactiveExpirer extends Expirer {
@Override
protected void expire(List<Node> expired) {
- expired.forEach(node -> {
- nodeRepository.nodes().deallocate(node, Agent.InactiveExpirer, "Expired by InactiveExpirer");
- });
+ nodeRepository.nodes().performOn(NodeList.copyOf(expired),
+ node -> node.state() == State.inactive && isExpired(node),
+ (node, lock) -> nodeRepository.nodes().deallocate(node, Agent.InactiveExpirer, "Expired by InactiveExpirer"));
}
@Override
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ReservationExpirer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ReservationExpirer.java
index 6f06a2ac22e..2484f496ece 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ReservationExpirer.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ReservationExpirer.java
@@ -25,6 +25,8 @@ public class ReservationExpirer extends Expirer {
}
@Override
- protected void expire(List<Node> expired) { nodeRepository().nodes().deallocate(expired, Agent.ReservationExpirer, "Expired by ReservationExpirer"); }
+ protected void expire(List<Node> expired) {
+ nodeRepository().nodes().deallocate(expired, Agent.ReservationExpirer, "Expired by ReservationExpirer");
+ }
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java
index cc7db3c138a..1ff6d2b300d 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java
@@ -113,7 +113,7 @@ public record IP() {
*
* @throws IllegalArgumentException if there are IP conflicts with existing nodes
*/
- public static List<Node> verify(List<Node> nodes, LockedNodeList allNodes) {
+ public static LockedNodeList verify(List<Node> nodes, LockedNodeList allNodes) {
NodeList sortedNodes = allNodes.sortedBy(Comparator.comparing(Node::hostname));
for (var node : nodes) {
for (var other : sortedNodes) {
@@ -135,7 +135,7 @@ public record IP() {
other.hostname());
}
}
- return nodes;
+ return allNodes.childList(nodes);
}
/** Returns whether IP address of existing node can be assigned to node */
@@ -152,7 +152,7 @@ public record IP() {
}
public static Node verify(Node node, LockedNodeList allNodes) {
- return verify(List.of(node), allNodes).get(0);
+ return verify(List.of(node), allNodes).asList().get(0);
}
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java
index deaf3054362..490e7b9ac33 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java
@@ -1,7 +1,6 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.provision.node;
-import com.yahoo.collections.ListMap;
import com.yahoo.component.Version;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.ApplicationTransaction;
@@ -10,6 +9,7 @@ import com.yahoo.config.provision.Flavor;
import com.yahoo.config.provision.NodeResources;
import com.yahoo.config.provision.NodeType;
import com.yahoo.config.provision.Zone;
+import com.yahoo.time.TimeBudget;
import com.yahoo.transaction.Mutex;
import com.yahoo.transaction.NestedTransaction;
import com.yahoo.vespa.applicationmodel.HostName;
@@ -17,6 +17,7 @@ import com.yahoo.vespa.applicationmodel.InfrastructureApplication;
import com.yahoo.vespa.hosted.provision.LockedNodeList;
import com.yahoo.vespa.hosted.provision.NoSuchNodeException;
import com.yahoo.vespa.hosted.provision.Node;
+import com.yahoo.vespa.hosted.provision.Node.State;
import com.yahoo.vespa.hosted.provision.NodeList;
import com.yahoo.vespa.hosted.provision.NodeMutex;
import com.yahoo.vespa.hosted.provision.applications.Applications;
@@ -31,20 +32,26 @@ import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Comparator;
import java.util.EnumSet;
+import java.util.HashSet;
+import java.util.Iterator;
import java.util.List;
-import java.util.Map;
-import java.util.Objects;
+import java.util.NavigableSet;
import java.util.Optional;
import java.util.Set;
+import java.util.TreeSet;
import java.util.function.BiFunction;
+import java.util.function.Function;
import java.util.function.Predicate;
import java.util.logging.Level;
import java.util.logging.Logger;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
import static com.yahoo.vespa.hosted.provision.restapi.NodePatcher.DROP_DOCUMENTS_REPORT;
+import static java.util.Comparator.comparing;
+import static java.util.stream.Collectors.groupingBy;
+import static java.util.stream.Collectors.joining;
/**
* The nodes in the node repo and their state transitions
@@ -148,7 +155,7 @@ public class Nodes {
if (existing.isPresent())
throw new IllegalStateException("Cannot add " + node + ": A node with this name already exists");
}
- return db.addNodesInState(nodes.asList(), Node.State.reserved, Agent.system);
+ return db.addNodesInState(nodes, Node.State.reserved, Agent.system);
}
/**
@@ -157,7 +164,8 @@ public class Nodes {
* with the history of that node.
*/
public List<Node> addNodes(List<Node> nodes, Agent agent) {
- try (Mutex lock = lockUnallocated()) {
+ try (NodeMutexes existingNodesLocks = lockAndGetAll(nodes, Optional.empty()); // Locks for any existing nodes we may remove.
+ Mutex allocationLock = lockUnallocated()) {
List<Node> nodesToAdd = new ArrayList<>();
List<Node> nodesToRemove = new ArrayList<>();
for (int i = 0; i < nodes.size(); i++) {
@@ -194,7 +202,7 @@ public class Nodes {
}
NestedTransaction transaction = new NestedTransaction();
db.removeNodes(nodesToRemove, transaction);
- List<Node> resultingNodes = db.addNodesInState(IP.Config.verify(nodesToAdd, list(lock)), Node.State.provisioned, agent, transaction);
+ List<Node> resultingNodes = db.addNodesInState(IP.Config.verify(nodesToAdd, list(allocationLock)), Node.State.provisioned, agent, transaction);
transaction.commit();
return resultingNodes;
}
@@ -218,7 +226,7 @@ public class Nodes {
}
/** Activate nodes. This method does <b>not</b> lock the node repository. */
- public List<Node> activate(List<Node> nodes, NestedTransaction transaction) {
+ public List<Node> activate(List<Node> nodes, ApplicationTransaction transaction) {
return db.writeTo(Node.State.active, nodes, Agent.application, Optional.empty(), transaction);
}
@@ -229,8 +237,7 @@ public class Nodes {
* @param reusable move the node directly to {@link Node.State#dirty} after removal
*/
public void setRemovable(NodeList nodes, boolean reusable) {
- performOn(nodes, (node, mutex) -> write(node.with(node.allocation().get().removable(true, reusable)),
- mutex));
+ performOn(nodes, (node, mutex) -> write(node.with(node.allocation().get().removable(true, reusable)), mutex));
}
/**
@@ -239,7 +246,7 @@ public class Nodes {
*/
public List<Node> deactivate(List<Node> nodes, ApplicationTransaction transaction) {
if ( ! zone.environment().isProduction() || zone.system().isCd())
- return deallocate(nodes, Agent.application, "Deactivated by application", transaction.nested());
+ return deallocate(nodes, Agent.application, "Deactivated by application", transaction);
NodeList nodeList = NodeList.copyOf(nodes);
NodeList stateless = nodeList.stateless();
@@ -247,9 +254,9 @@ public class Nodes {
NodeList statefulToInactive = stateful.not().reusable();
NodeList statefulToDirty = stateful.reusable();
List<Node> written = new ArrayList<>();
- written.addAll(deallocate(stateless.asList(), Agent.application, "Deactivated by application", transaction.nested()));
- written.addAll(deallocate(statefulToDirty.asList(), Agent.application, "Deactivated by application (recycled)", transaction.nested()));
- written.addAll(db.writeTo(Node.State.inactive, statefulToInactive.asList(), Agent.application, Optional.empty(), transaction.nested()));
+ written.addAll(deallocate(stateless.asList(), Agent.application, "Deactivated by application", transaction));
+ written.addAll(deallocate(statefulToDirty.asList(), Agent.application, "Deactivated by application (recycled)", transaction));
+ written.addAll(db.writeTo(Node.State.inactive, statefulToInactive.asList(), Agent.application, Optional.empty(), transaction));
return written;
}
@@ -258,21 +265,9 @@ public class Nodes {
* transaction commits.
*/
public List<Node> fail(List<Node> nodes, ApplicationTransaction transaction) {
- return fail(nodes, Agent.application, "Failed by application", transaction.nested());
- }
-
- public List<Node> fail(List<Node> nodes, Agent agent, String reason) {
- NestedTransaction transaction = new NestedTransaction();
- nodes = fail(nodes, agent, reason, transaction);
- transaction.commit();
- return nodes;
- }
-
- private List<Node> fail(List<Node> nodes, Agent agent, String reason, NestedTransaction transaction) {
- nodes = nodes.stream()
- .map(n -> n.withWantToFail(false, agent, clock.instant()))
- .toList();
- return db.writeTo(Node.State.failed, nodes, agent, Optional.of(reason), transaction);
+ return db.writeTo(Node.State.failed,
+ nodes.stream().map(n -> n.withWantToFail(false, Agent.application, clock.instant())).toList(),
+ Agent.application, Optional.of("Failed by application"), transaction);
}
/** Move nodes to the dirty state */
@@ -282,40 +277,48 @@ public class Nodes {
public List<Node> deallocateRecursively(String hostname, Agent agent, String reason) {
Node nodeToDirty = node(hostname).orElseThrow(() -> new NoSuchNodeException("Could not deallocate " + hostname + ": Node not found"));
-
- List<Node> nodesToDirty =
- (nodeToDirty.type().isHost() ?
- Stream.concat(list().childrenOf(hostname).asList().stream(), Stream.of(nodeToDirty)) :
- Stream.of(nodeToDirty)).filter(node -> node.state() != Node.State.dirty).toList();
- List<String> hostnamesNotAllowedToDirty = nodesToDirty.stream()
- .filter(node -> node.state() != Node.State.provisioned)
- .filter(node -> node.state() != Node.State.failed)
- .filter(node -> node.state() != Node.State.parked)
- .filter(node -> node.state() != Node.State.breakfixed)
- .map(Node::hostname).toList();
- if ( ! hostnamesNotAllowedToDirty.isEmpty())
- illegal("Could not deallocate " + nodeToDirty + ": " +
- hostnamesNotAllowedToDirty + " are not in states [provisioned, failed, parked, breakfixed]");
-
- return nodesToDirty.stream().map(node -> deallocate(node, agent, reason)).toList();
+ List<Node> nodesToDirty = new ArrayList<>();
+ try (RecursiveNodeMutexes locked = lockAndGetRecursively(hostname, Optional.empty())) {
+ for (NodeMutex child : locked.children())
+ if (child.node().state() != Node.State.dirty)
+ nodesToDirty.add(child.node());
+
+ if (locked.parent().node().state() != State.dirty)
+ nodesToDirty.add(locked.parent().node());
+
+ List<String> hostnamesNotAllowedToDirty = nodesToDirty.stream()
+ .filter(node -> node.state() != Node.State.provisioned)
+ .filter(node -> node.state() != Node.State.failed)
+ .filter(node -> node.state() != Node.State.parked)
+ .filter(node -> node.state() != Node.State.breakfixed)
+ .map(Node::hostname).toList();
+ if ( ! hostnamesNotAllowedToDirty.isEmpty())
+ illegal("Could not deallocate " + nodeToDirty + ": " +
+ hostnamesNotAllowedToDirty + " are not in states [provisioned, failed, parked, breakfixed]");
+
+ return nodesToDirty.stream().map(node -> deallocate(node, agent, reason)).toList();
+ }
}
/**
- * Set a node dirty or parked, allowed if it is in the provisioned, inactive, failed or parked state.
+ * Set a node dirty or parked, allowed if it is in the provisioned, inactive, failed or parked state.
* Use this to clean newly provisioned nodes or to recycle failed nodes which have been repaired or put on hold.
*/
public Node deallocate(Node node, Agent agent, String reason) {
- NestedTransaction transaction = new NestedTransaction();
- Node deallocated = deallocate(node, agent, reason, transaction);
- transaction.commit();
- return deallocated;
+ try (NodeMutex locked = lockAndGetRequired(node)) {
+ NestedTransaction transaction = new NestedTransaction();
+ Node deallocated = deallocate(locked.node(), agent, reason, transaction);
+ transaction.commit();
+ return deallocated;
+ }
}
- public List<Node> deallocate(List<Node> nodes, Agent agent, String reason, NestedTransaction transaction) {
- return nodes.stream().map(node -> deallocate(node, agent, reason, transaction)).toList();
+ public List<Node> deallocate(List<Node> nodes, Agent agent, String reason, ApplicationTransaction transaction) {
+ return nodes.stream().map(node -> deallocate(node, agent, reason, transaction.nested())).toList();
}
- public Node deallocate(Node node, Agent agent, String reason, NestedTransaction transaction) {
+    // Precondition: the caller must already hold the node's lock (and pass it implicitly); this method does not lock.
+ private Node deallocate(Node node, Agent agent, String reason, NestedTransaction transaction) {
if (parkOnDeallocationOf(node, agent)) {
return park(node.hostname(), false, agent, reason, transaction);
} else {
@@ -339,7 +342,9 @@ public class Nodes {
}
public Node fail(String hostname, boolean forceDeprovision, Agent agent, String reason) {
- return move(hostname, Node.State.failed, agent, forceDeprovision, Optional.of(reason));
+ try (NodeMutex lock = lockAndGetRequired(hostname)) {
+ return move(hostname, Node.State.failed, agent, forceDeprovision, Optional.of(reason), lock);
+ }
}
/**
@@ -350,14 +355,16 @@ public class Nodes {
* @return all the nodes that were changed by this request
*/
public List<Node> failOrMarkRecursively(String hostname, Agent agent, String reason) {
- NodeList children = list().childrenOf(hostname);
- List<Node> changed = performOn(children, (node, lock) -> failOrMark(node, agent, reason, lock));
-
- if (children.state(Node.State.active).isEmpty())
- changed.add(move(hostname, Node.State.failed, agent, false, Optional.of(reason)));
- else
- changed.addAll(performOn(NodeList.of(node(hostname).orElseThrow()), (node, lock) -> failOrMark(node, agent, reason, lock)));
+ List<Node> changed = new ArrayList<>();
+ try (RecursiveNodeMutexes nodes = lockAndGetRecursively(hostname, Optional.empty())) {
+ for (NodeMutex child : nodes.children())
+ changed.add(failOrMark(child.node(), agent, reason, child));
+ if (changed.stream().noneMatch(child -> child.state() == Node.State.active))
+ changed.add(move(hostname, Node.State.failed, agent, false, Optional.of(reason), nodes.parent()));
+ else
+ changed.add(failOrMark(nodes.parent().node(), agent, reason, nodes.parent()));
+ }
return changed;
}
@@ -367,7 +374,7 @@ public class Nodes {
write(node, lock);
return node;
} else {
- return move(node.hostname(), Node.State.failed, agent, false, Optional.of(reason));
+ return move(node.hostname(), Node.State.failed, agent, false, Optional.of(reason), lock);
}
}
@@ -389,10 +396,12 @@ public class Nodes {
* @throws NoSuchNodeException if the node is not found
*/
public Node park(String hostname, boolean forceDeprovision, Agent agent, String reason) {
- NestedTransaction transaction = new NestedTransaction();
- Node parked = park(hostname, forceDeprovision, agent, reason, transaction);
- transaction.commit();
- return parked;
+ try (NodeMutex locked = lockAndGetRequired(hostname)) {
+ NestedTransaction transaction = new NestedTransaction();
+ Node parked = park(hostname, forceDeprovision, agent, reason, transaction);
+ transaction.commit();
+ return parked;
+ }
}
private Node park(String hostname, boolean forceDeprovision, Agent agent, String reason, NestedTransaction transaction) {
@@ -415,36 +424,38 @@ public class Nodes {
* @throws NoSuchNodeException if the node is not found
*/
public Node reactivate(String hostname, Agent agent, String reason) {
- return move(hostname, Node.State.active, agent, false, Optional.of(reason));
+ try (NodeMutex lock = lockAndGetRequired(hostname)) {
+ return move(hostname, Node.State.active, agent, false, Optional.of(reason), lock);
+ }
}
/**
* Moves a host to breakfixed state, removing any children.
*/
public List<Node> breakfixRecursively(String hostname, Agent agent, String reason) {
- Node node = requireNode(hostname);
- try (Mutex lock = lockUnallocated()) {
- requireBreakfixable(node);
+ try (RecursiveNodeMutexes locked = lockAndGetRecursively(hostname, Optional.empty())) {
+ requireBreakfixable(locked.parent().node());
NestedTransaction transaction = new NestedTransaction();
- List<Node> removed = removeChildren(node, false, transaction);
- removed.add(move(node.hostname(), Node.State.breakfixed, agent, false, Optional.of(reason), transaction));
+ removeChildren(locked, false, transaction);
+ move(hostname, Node.State.breakfixed, agent, false, Optional.of(reason), transaction);
transaction.commit();
- return removed;
+ return locked.nodes().nodes().stream().map(NodeMutex::node).toList();
}
}
private List<Node> moveRecursively(String hostname, Node.State toState, Agent agent, Optional<String> reason) {
- NestedTransaction transaction = new NestedTransaction();
- List<Node> moved = list().childrenOf(hostname).asList().stream()
- .map(child -> move(child.hostname(), toState, agent, false, reason, transaction))
- .collect(Collectors.toCollection(ArrayList::new));
- moved.add(move(hostname, toState, agent, false, reason, transaction));
- transaction.commit();
- return moved;
+ try (RecursiveNodeMutexes locked = lockAndGetRecursively(hostname, Optional.empty())) {
+ List<Node> moved = new ArrayList<>();
+ NestedTransaction transaction = new NestedTransaction();
+ for (NodeMutex node : locked.nodes().nodes())
+ moved.add(move(node.node().hostname(), toState, agent, false, reason, transaction));
+ transaction.commit();
+ return moved;
+ }
}
/** Move a node to given state */
- private Node move(String hostname, Node.State toState, Agent agent, boolean forceDeprovision, Optional<String> reason) {
+ private Node move(String hostname, Node.State toState, Agent agent, boolean forceDeprovision, Optional<String> reason, Mutex lock) {
NestedTransaction transaction = new NestedTransaction();
Node moved = move(hostname, toState, agent, forceDeprovision, reason, transaction);
transaction.commit();
@@ -453,8 +464,7 @@ public class Nodes {
/** Move a node to given state as part of a transaction */
private Node move(String hostname, Node.State toState, Agent agent, boolean forceDeprovision, Optional<String> reason, NestedTransaction transaction) {
- // TODO: Work out a safe lock acquisition strategy for moves. Lock is only held while adding operations to
- // transaction, but lock must also be held while committing
+        // TODO: The caller already holds this node's lock, so this re-acquisition is redundant; we only need to re-read the node. Consider switching to requireNode(hostname) later.
try (NodeMutex lock = lockAndGetRequired(hostname)) {
Node node = lock.node();
if (toState == Node.State.active) {
@@ -523,17 +533,18 @@ public class Nodes {
}
public List<Node> removeRecursively(Node node, boolean force) {
- try (Mutex lock = lockUnallocated()) {
- requireRemovable(node, false, force);
+ try (RecursiveNodeMutexes locked = lockAndGetRecursively(node.hostname(), Optional.empty())) {
+ requireRemovable(locked.parent().node(), false, force);
NestedTransaction transaction = new NestedTransaction();
List<Node> removed;
- if (!node.type().isHost()) {
+ if ( ! node.type().isHost()) {
removed = List.of(node);
db.removeNodes(removed, transaction);
- } else {
- removed = removeChildren(node, force, transaction);
+ }
+ else {
+ removeChildren(locked, force, transaction);
move(node.hostname(), Node.State.deprovisioned, Agent.system, false, Optional.empty(), transaction);
- removed.add(node);
+ removed = locked.nodes().nodes().stream().map(NodeMutex::node).toList();
}
transaction.commit();
return removed;
@@ -542,20 +553,22 @@ public class Nodes {
/** Forgets a deprovisioned node. This removes all traces of the node in the node repository. */
public void forget(Node node) {
- if (node.state() != Node.State.deprovisioned)
- throw new IllegalArgumentException(node + " must be deprovisioned before it can be forgotten");
- if (node.status().wantToRebuild())
- throw new IllegalArgumentException(node + " is rebuilding and cannot be forgotten");
- NestedTransaction transaction = new NestedTransaction();
- db.removeNodes(List.of(node), transaction);
- transaction.commit();
+ try (NodeMutex locked = lockAndGetRequired(node.hostname())) {
+ if (node.state() != Node.State.deprovisioned)
+ throw new IllegalArgumentException(node + " must be deprovisioned before it can be forgotten");
+ if (node.status().wantToRebuild())
+ throw new IllegalArgumentException(node + " is rebuilding and cannot be forgotten");
+ NestedTransaction transaction = new NestedTransaction();
+ db.removeNodes(List.of(node), transaction);
+ transaction.commit();
+ }
}
- private List<Node> removeChildren(Node node, boolean force, NestedTransaction transaction) {
- List<Node> children = list().childrenOf(node).asList();
+ private void removeChildren(RecursiveNodeMutexes nodes, boolean force, NestedTransaction transaction) {
+ if (nodes.children().isEmpty()) return;
+ List<Node> children = nodes.children().stream().map(NodeMutex::node).toList();
children.forEach(child -> requireRemovable(child, true, force));
db.removeNodes(children, transaction);
- return new ArrayList<>(children);
}
/**
@@ -717,8 +730,8 @@ public class Nodes {
return db.writeTo(nodes, Agent.system, Optional.empty());
}
- private List<Node> performOn(Predicate<Node> filter, BiFunction<Node, Mutex, Node> action) {
- return performOn(list().matching(filter), action);
+ public List<Node> performOn(Predicate<Node> filter, BiFunction<Node, Mutex, Node> action) {
+ return performOn(list(), filter, action);
}
/**
@@ -727,35 +740,33 @@ public class Nodes {
* @param action the action to perform
* @return the set of nodes on which the action was performed, as they became as a result of the operation
*/
- private List<Node> performOn(NodeList nodes, BiFunction<Node, Mutex, Node> action) {
- List<Node> unallocatedNodes = new ArrayList<>();
- ListMap<ApplicationId, Node> allocatedNodes = new ListMap<>();
+ public List<Node> performOn(NodeList nodes, BiFunction<Node, Mutex, Node> action) {
+ return performOn(nodes, __ -> true, action);
+ }
- // Group matching nodes by the lock needed
- for (Node node : nodes) {
- Optional<ApplicationId> applicationId = applicationIdForLock(node);
- if (applicationId.isPresent())
- allocatedNodes.put(applicationId.get(), node);
- else
- unallocatedNodes.add(node);
- }
+ public List<Node> performOn(NodeList nodes, Predicate<Node> filter, BiFunction<Node, Mutex, Node> action) {
+ List<Node> resultingNodes = new ArrayList<>();
+ nodes.stream().collect(groupingBy(Nodes::applicationIdForLock))
+ .forEach((applicationId, nodeList) -> { // Grouped only to reduce number of lock acquire/release cycles.
+ try (NodeMutexes locked = lockAndGetAll(nodeList, Optional.empty())) {
+ for (NodeMutex node : locked.nodes())
+ if (filter.test(node.node()))
+ resultingNodes.add(action.apply(node.node(), node));
+ }
+ });
+ return resultingNodes;
+ }
+
+ public List<Node> performOnRecursively(NodeList parents, Predicate<RecursiveNodeMutexes> filter, Function<RecursiveNodeMutexes, List<Node>> action) {
+ for (Node node : parents)
+ if (node.parentHostname().isPresent())
+ throw new IllegalArgumentException(node + " is not a parent host");
- // Perform operation while holding appropriate lock
List<Node> resultingNodes = new ArrayList<>();
- try (Mutex lock = lockUnallocated()) {
- for (Node node : unallocatedNodes) {
- Optional<Node> currentNode = db.readNode(node.hostname()); // Re-read while holding lock
- if (currentNode.isEmpty()) continue;
- resultingNodes.add(action.apply(currentNode.get(), lock));
- }
- }
- for (Map.Entry<ApplicationId, List<Node>> applicationNodes : allocatedNodes.entrySet()) {
- try (Mutex lock = applications.lock(applicationNodes.getKey())) {
- for (Node node : applicationNodes.getValue()) {
- Optional<Node> currentNode = db.readNode(node.hostname()); // Re-read while holding lock
- if (currentNode.isEmpty()) continue;
- resultingNodes.add(action.apply(currentNode.get(), lock));
- }
+ for (Node parent : parents) {
+ try (RecursiveNodeMutexes locked = lockAndGetRecursively(parent.hostname(), Optional.empty())) {
+ if (filter.test(locked))
+ resultingNodes.addAll(action.apply(locked));
}
}
return resultingNodes;
@@ -818,9 +829,7 @@ public class Nodes {
return Optional.empty();
}
- if (node.type() != NodeType.tenant ||
- Objects.equals(freshNode.get().allocation().map(Allocation::owner),
- staleNode.allocation().map(Allocation::owner))) {
+ if (applicationIdForLock(freshNode.get()).equals(applicationIdForLock(staleNode))) {
NodeMutex nodeMutex = new NodeMutex(freshNode.get(), lockToClose);
lockToClose = null;
return Optional.of(nodeMutex);
@@ -881,6 +890,168 @@ public class Nodes {
return node(hostname).orElseThrow(() -> new NoSuchNodeException("No node with hostname '" + hostname + "'"));
}
+ /**
+ * Locks the children of the given node, the node itself, and finally takes the unallocated lock.
+ * <br>
+ * When taking multiple locks, it's crucial that we always take them in the same order, to avoid deadlocks.
+ * We want to take the most contended locks last, so that we don't block other operations for longer than necessary.
+ * This method does that, by first taking the locks for any children the given node may have, and then the node itself.
+ * (This is enforced by taking host locks after tenant node locks, in {@link #lockAndGetAll(Collection, Optional)}.)
+ * Finally, the allocation lock is taken, to ensure no new children are added while we hold this snapshot.
+ * Unfortunately, since that lock is taken last, we may detect new nodes after taking it, and then we have to retry.
+ * Closing the returned {@link RecursiveNodeMutexes} will release all the locks, and the locks should not be closed elsewhere.
+ */
+ public RecursiveNodeMutexes lockAndGetRecursively(String hostname, Optional<Duration> timeout) {
+ TimeBudget budget = TimeBudget.fromNow(clock, timeout.orElse(Duration.ofMinutes(2)));
+ Set<Node> children = new HashSet<>(list().childrenOf(hostname).asList());
+ Optional<Node> node = node(hostname);
+
+ int attempts = 5; // We'll retry locking the whole list of children this many times, in case new children appear.
+ for (int attempt = 0; attempt < attempts; attempt++) {
+ NodeMutexes mutexes = null;
+ Mutex unallocatedLock = null;
+ try {
+ // First, we lock all the children, and the host; then we take the allocation lock to ensure our snapshot is valid.
+ List<Node> nodes = new ArrayList<>(children.size() + 1);
+ nodes.addAll(children);
+ node.ifPresent(nodes::add);
+ mutexes = lockAndGetAll(nodes, budget.timeLeftOrThrow());
+ unallocatedLock = db.lockInactive(budget.timeLeftOrThrow().get());
+ RecursiveNodeMutexes recursive = new RecursiveNodeMutexes(hostname, mutexes, unallocatedLock);
+ Set<Node> freshChildren = list().childrenOf(hostname).asSet();
+ Optional<Node> freshNode = recursive.parent.map(NodeMutex::node);
+ if (children.equals(freshChildren) && node.equals(freshNode)) {
+ // No new nodes have appeared, and none will now, so we have a consistent snapshot.
+ if (node.isEmpty() && ! children.isEmpty())
+ throw new IllegalStateException("node '" + hostname + "' was not found, but it has children: " + children);
+
+ mutexes = null;
+ unallocatedLock = null;
+ return recursive;
+ }
+ else {
+ // New nodes have appeared, so we need to let go of the locks and try again with the new set of nodes.
+ children = freshChildren;
+ node = freshNode;
+ }
+ }
+ finally {
+ if (unallocatedLock != null) unallocatedLock.close();
+ if (mutexes != null) mutexes.close();
+ }
+ }
+ throw new IllegalStateException("giving up (after " + attempts + " attempts) fetching an up to " +
+ "date recursive node set under lock for node " + hostname);
+ }
+
+ /** Locks all nodes in the given list, in a universal order, and returns the locks and nodes required. */
+ public NodeMutexes lockAndRequireAll(Collection<Node> nodes, Optional<Duration> timeout) {
+ return lockAndGetAll(nodes, timeout, true);
+ }
+
+ /** Locks all nodes in the given list, in a universal order, and returns the locks and nodes acquired. */
+ public NodeMutexes lockAndGetAll(Collection<Node> nodes, Optional<Duration> timeout) {
+ return lockAndGetAll(nodes, timeout, false);
+ }
+
+ /** Locks all nodes in the given list, in a universal order, and returns the locks and nodes. */
+ private NodeMutexes lockAndGetAll(Collection<Node> nodes, Optional<Duration> timeout, boolean required) {
+ TimeBudget budget = TimeBudget.fromNow(clock, timeout.orElse(Duration.ofMinutes(2)));
+ Comparator<Node> universalOrder = (a, b) -> {
+ Optional<ApplicationId> idA = applicationIdForLock(a);
+ Optional<ApplicationId> idB = applicationIdForLock(b);
+ if (idA.isPresent() != idB.isPresent()) return idA.isPresent() ? -1 : 1; // Allocated nodes first.
+ if (a.type() != b.type()) return a.type().compareTo(b.type()); // Tenant nodes first among those.
+ if ( ! idA.equals(idB)) return idA.get().compareTo(idB.get()); // Sort primarily by tenant owner id.
+ return a.hostname().compareTo(b.hostname()); // Sort secondarily by hostname.
+ };
+ NavigableSet<NodeMutex> locked = new TreeSet<>(comparing(NodeMutex::node, universalOrder));
+ NavigableSet<Node> unlocked = new TreeSet<>(universalOrder);
+ unlocked.addAll(nodes);
+ try {
+ int attempts = 10; // We'll accept getting the wrong lock at most this many times before giving up.
+ for (int attempt = 0; attempt < attempts; ) {
+ if (unlocked.isEmpty()) {
+ NodeMutexes mutexes = new NodeMutexes(List.copyOf(locked));
+ locked.clear();
+ return mutexes;
+ }
+
+ // If the first node is now earlier in lock order than some other locks we have, we need to close those and re-acquire them.
+ Node next = unlocked.pollFirst();
+ Set<NodeMutex> outOfOrder = locked.tailSet(new NodeMutex(next, () -> { }), false);
+ NodeMutexes.close(outOfOrder.iterator());
+ for (NodeMutex node : outOfOrder) unlocked.add(node.node());
+ outOfOrder.clear();
+
+ Mutex lock = lock(next, budget.timeLeftOrThrow());
+ try {
+ Optional<Node> fresh = node(next.hostname());
+ if (fresh.isEmpty()) {
+ if (required) throw new NoSuchNodeException("No node with hostname '" + next.hostname() + "'");
+ continue; // Node is gone; skip to close lock.
+ }
+
+ if (applicationIdForLock(fresh.get()).equals(applicationIdForLock(next))) {
+ // We held the right lock, so this node is ours now.
+ locked.add(new NodeMutex(fresh.get(), lock));
+ lock = null;
+ }
+ else {
+ // We held the wrong lock, and need to try again.
+ ++attempt;
+ unlocked.add(fresh.get());
+ }
+ }
+ finally {
+ // If we didn't hold the right lock, we must close the wrong one before we continue.
+ if (lock != null) lock.close();
+ }
+ }
+ throw new IllegalStateException("giving up (after " + attempts + " extra attempts) to lock nodes: " +
+ nodes.stream().map(Node::hostname).collect(joining(", ")));
+ }
+ finally {
+ // If we didn't manage to lock all nodes, we must close the ones we did lock before we throw.
+ NodeMutexes.close(locked.iterator());
+ }
+ }
+
+    /** Nodes with their locks, acquired in a universal order. */
+ public record NodeMutexes(List<NodeMutex> nodes) implements AutoCloseable {
+ @Override public void close() { close(nodes.iterator()); }
+ private static void close(Iterator<NodeMutex> nodes) {
+ if (nodes.hasNext()) try (NodeMutex node = nodes.next()) { close(nodes); }
+ }
+ }
+
+ /** A parent node, all its children, their locks acquired in a universal order, and then the unallocated lock. */
+ public static class RecursiveNodeMutexes implements AutoCloseable {
+
+ private final String hostname;
+ private final NodeMutexes nodes;
+ private final Mutex unallocatedLock;
+ private final List<NodeMutex> children;
+ private final Optional<NodeMutex> parent;
+
+ public RecursiveNodeMutexes(String hostname, NodeMutexes nodes, Mutex unallocatedLock) {
+ this.hostname = hostname;
+ this.nodes = nodes;
+ this.unallocatedLock = unallocatedLock;
+ this.children = nodes.nodes().stream().filter(node -> ! node.node().hostname().equals(hostname)).toList();
+ this.parent = nodes.nodes().stream().filter(node -> node.node().hostname().equals(hostname)).findFirst();
+ }
+
+ /** Any children of the node. */
+ public List<NodeMutex> children() { return children; }
+ /** The node itself, or throws if the node was not found. */
+ public NodeMutex parent() { return parent.orElseThrow(() -> new NoSuchNodeException("No node with hostname '" + hostname + "'")); }
+ /** Empty if the node was not found, or the node, and any children. */
+ public NodeMutexes nodes() { return nodes; }
+ /** Closes the allocation lock, and all the node locks. */
+ @Override public void close() { try (nodes; unallocatedLock) { } }
+ }
+
/** Returns the application ID that should be used for locking when modifying this node */
private static Optional<ApplicationId> applicationIdForLock(Node node) {
return switch (node.type()) {
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java
index fc008b7b9dc..037338cb2ed 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java
@@ -18,6 +18,7 @@ import com.yahoo.vespa.curator.Lock;
import com.yahoo.vespa.curator.recipes.CuratorCounter;
import com.yahoo.vespa.curator.transaction.CuratorOperations;
import com.yahoo.vespa.curator.transaction.CuratorTransaction;
+import com.yahoo.vespa.hosted.provision.LockedNodeList;
import com.yahoo.vespa.hosted.provision.Node;
import com.yahoo.vespa.hosted.provision.applications.Application;
import com.yahoo.vespa.hosted.provision.archive.ArchiveUris;
@@ -105,7 +106,7 @@ public class CuratorDb {
}
/** Adds a set of nodes. Rollbacks/fails transaction if any node is not in the expected state. */
- public List<Node> addNodesInState(List<Node> nodes, Node.State expectedState, Agent agent, NestedTransaction transaction) {
+ public List<Node> addNodesInState(LockedNodeList nodes, Node.State expectedState, Agent agent, NestedTransaction transaction) {
CuratorTransaction curatorTransaction = db.newCuratorTransactionIn(transaction);
for (Node node : nodes) {
if (node.state() != expectedState)
@@ -116,10 +117,10 @@ public class CuratorDb {
curatorTransaction.add(CuratorOperations.create(nodePath(node).getAbsolute(), serialized));
}
transaction.onCommitted(() -> nodes.forEach(node -> log.log(Level.INFO, "Added " + node)));
- return nodes;
+ return nodes.asList();
}
- public List<Node> addNodesInState(List<Node> nodes, Node.State expectedState, Agent agent) {
+ public List<Node> addNodesInState(LockedNodeList nodes, Node.State expectedState, Agent agent) {
NestedTransaction transaction = new NestedTransaction();
List<Node> writtenNodes = addNodesInState(nodes, expectedState, agent, transaction);
transaction.commit();
@@ -175,6 +176,7 @@ public class CuratorDb {
return writtenNodes;
}
}
+
public Node writeTo(Node.State toState, Node node, Agent agent, Optional<String> reason) {
return writeTo(toState, Collections.singletonList(node), agent, reason).get(0);
}
@@ -192,6 +194,12 @@ public class CuratorDb {
*/
public List<Node> writeTo(Node.State toState, List<Node> nodes,
Agent agent, Optional<String> reason,
+ ApplicationTransaction transaction) {
+ return writeTo(toState, nodes, agent, reason, transaction.nested());
+ }
+
+ public List<Node> writeTo(Node.State toState, List<Node> nodes,
+ Agent agent, Optional<String> reason,
NestedTransaction transaction) {
if (nodes.isEmpty()) return nodes;
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java
index caf936e8aeb..c25f33bc8c2 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java
@@ -88,7 +88,7 @@ class Activator {
NodeList activeToRemove = oldActive.matching(node -> ! hostnames.contains(node.hostname()));
remove(activeToRemove, transaction); // TODO: Pass activation time in this call and next line
- nodeRepository.nodes().activate(newActive.asList(), transaction.nested()); // activate also continued active to update node state
+ nodeRepository.nodes().activate(newActive.asList(), transaction); // activate also continued active to update node state
rememberResourceChange(transaction, generation, activationTime,
oldActive.not().retired(),
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java
index b7d6e0a9dd9..714374ccb8a 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java
@@ -176,7 +176,7 @@ public class MockNodeRepository extends NodeRepository {
.build());
// Ready all nodes, except 7 and 55
- nodes = nodes().addNodes(nodes, Agent.system);
+ nodes = new ArrayList<>(nodes().addNodes(nodes, Agent.system));
nodes.remove(node7);
nodes.remove(node55);
nodes = nodes().deallocate(nodes, Agent.system, getClass().getSimpleName());
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/RealDataScenarioTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/RealDataScenarioTest.java
index 29ebf1789c0..9c843b3eb01 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/RealDataScenarioTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/RealDataScenarioTest.java
@@ -141,7 +141,7 @@ public class RealDataScenarioTest {
if (nodeNext.get()) {
String json = input.substring(input.indexOf("{\""), input.lastIndexOf('}') + 1);
Node node = nodeSerializer.fromJson(json.getBytes(UTF_8));
- nodeRepository.database().addNodesInState(List.of(node), node.state(), Agent.system);
+ nodeRepository.database().addNodesInState(new LockedNodeList(List.of(node), () -> { }), node.state(), Agent.system);
nodeNext.set(false);
} else {
if (!zkNodePathPattern.matcher(input).matches()) return;
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityCheckerTester.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityCheckerTester.java
index f8ec271ce5f..523feeeb303 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityCheckerTester.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityCheckerTester.java
@@ -23,6 +23,7 @@ import com.yahoo.test.ManualClock;
import com.yahoo.vespa.curator.Curator;
import com.yahoo.vespa.curator.mock.MockCurator;
import com.yahoo.vespa.flags.InMemoryFlagSource;
+import com.yahoo.vespa.hosted.provision.LockedNodeList;
import com.yahoo.vespa.hosted.provision.Node;
import com.yahoo.vespa.hosted.provision.NodeRepository;
import com.yahoo.vespa.hosted.provision.autoscale.MemoryMetricsDb;
@@ -201,7 +202,7 @@ public class CapacityCheckerTester {
nodeRepository.nodes().addNodes(hostsWithChildren.getOrDefault(tenantHostApp, List.of()), Agent.system);
hostsWithChildren.forEach((applicationId, nodes) -> {
if (applicationId.equals(tenantHostApp)) return;
- nodeRepository.database().addNodesInState(nodes, Node.State.active, Agent.system);
+ nodeRepository.database().addNodesInState(new LockedNodeList(nodes, () -> { }), Node.State.active, Agent.system);
});
nodeRepository.nodes().addNodes(createEmptyHosts(numHosts, numEmptyHosts, emptyHostExcessCapacity, emptyHostExcessIps), Agent.system);
@@ -322,9 +323,9 @@ public class CapacityCheckerTester {
}
}
- nodeRepository.database().addNodesInState(hosts, Node.State.active, Agent.system);
+ nodeRepository.database().addNodesInState(new LockedNodeList(hosts, () -> { }), Node.State.active, Agent.system);
nodes.forEach((application, applicationNodes) -> {
- nodeRepository.database().addNodesInState(applicationNodes, Node.State.active, Agent.system);
+ nodeRepository.database().addNodesInState(new LockedNodeList(applicationNodes, () -> { }), Node.State.active, Agent.system);
});
updateCapacityChecker();
}
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java
index ddd7413567a..262616d5eac 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java
@@ -6,6 +6,7 @@ import com.yahoo.config.provision.ClusterMembership;
import com.yahoo.config.provision.Flavor;
import com.yahoo.config.provision.NodeResources;
import com.yahoo.config.provision.NodeType;
+import com.yahoo.vespa.hosted.provision.LockedNodeList;
import com.yahoo.vespa.hosted.provision.Node;
import com.yahoo.vespa.hosted.provision.node.Agent;
import com.yahoo.vespa.hosted.provision.node.Allocation;
@@ -45,7 +46,7 @@ public class DirtyExpirerTest {
false))
.build();
- tester.nodeRepository().database().addNodesInState(List.of(node), node.state(), Agent.system);
+ tester.nodeRepository().database().addNodesInState(new LockedNodeList(List.of(node), () -> { }), node.state(), Agent.system);
Duration expiryTimeout = Duration.ofMinutes(30);
DirtyExpirer expirer = new DirtyExpirer(tester.nodeRepository(), expiryTimeout, new TestMetric());
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainerTest.java
index 925a34c0419..c16ed47a216 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainerTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainerTest.java
@@ -27,6 +27,7 @@ import com.yahoo.test.ManualClock;
import com.yahoo.vespa.flags.InMemoryFlagSource;
import com.yahoo.vespa.flags.PermanentFlags;
import com.yahoo.vespa.flags.custom.ClusterCapacity;
+import com.yahoo.vespa.hosted.provision.LockedNodeList;
import com.yahoo.vespa.hosted.provision.Node;
import com.yahoo.vespa.hosted.provision.Node.State;
import com.yahoo.vespa.hosted.provision.NodeList;
@@ -750,7 +751,7 @@ public class HostCapacityMaintainerTest {
createNode("host4", Optional.empty(), NodeType.host, Node.State.provisioned, null),
createNode("host4-1", Optional.of("host4"), NodeType.tenant, Node.State.reserved, tenantApp),
createNode("host4-2", Optional.of("host4"), NodeType.tenant, Node.State.reserved, tenantApp))
- .forEach(node -> nodeRepository.database().addNodesInState(List.of(node), node.state(), Agent.system));
+ .forEach(node -> nodeRepository.database().addNodesInState(new LockedNodeList(List.of(node), () -> { }), node.state(), Agent.system));
return this;
}
@@ -772,7 +773,7 @@ public class HostCapacityMaintainerTest {
private Node addNode(String hostname, Optional<String> parentHostname, NodeType nodeType, Node.State state, ApplicationId application, Duration hostTTL) {
Node node = createNode(hostname, parentHostname, nodeType, state, application, hostTTL);
- return nodeRepository.database().addNodesInState(List.of(node), node.state(), Agent.system).get(0);
+ return nodeRepository.database().addNodesInState(new LockedNodeList(List.of(node), () -> { }), node.state(), Agent.system).get(0);
}
private Node createNode(String hostname, Optional<String> parentHostname, NodeType nodeType,
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java
index 7d90de1ccaf..83aea78ce58 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java
@@ -3,6 +3,7 @@ package com.yahoo.vespa.hosted.provision.maintenance;
import com.yahoo.component.Version;
import com.yahoo.config.provision.ApplicationId;
+import com.yahoo.config.provision.ApplicationTransaction;
import com.yahoo.config.provision.Capacity;
import com.yahoo.config.provision.ClusterMembership;
import com.yahoo.config.provision.ClusterResources;
@@ -10,11 +11,13 @@ import com.yahoo.config.provision.ClusterSpec;
import com.yahoo.config.provision.NodeFlavors;
import com.yahoo.config.provision.NodeResources;
import com.yahoo.config.provision.NodeType;
+import com.yahoo.config.provision.ProvisionLock;
import com.yahoo.jdisc.Metric;
import com.yahoo.transaction.Mutex;
import com.yahoo.transaction.NestedTransaction;
import com.yahoo.vespa.applicationmodel.ApplicationInstance;
import com.yahoo.vespa.applicationmodel.ApplicationInstanceReference;
+import com.yahoo.vespa.applicationmodel.InfrastructureApplication;
import com.yahoo.vespa.curator.stats.LockStats;
import com.yahoo.vespa.hosted.provision.LockedNodeList;
import com.yahoo.vespa.hosted.provision.Node;
@@ -210,7 +213,8 @@ public class MetricsReporterTest {
}
NestedTransaction transaction = new NestedTransaction();
- nodeRepository.nodes().activate(nodeRepository.nodes().list().nodeType(NodeType.host).asList(), transaction);
+ nodeRepository.nodes().activate(nodeRepository.nodes().list().nodeType(NodeType.host).asList(),
+ new ApplicationTransaction(new ProvisionLock(InfrastructureApplication.TENANT_HOST.id(), () -> { }), transaction));
transaction.commit();
Orchestrator orchestrator = mock(Orchestrator.class);
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ProvisionedExpirerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ProvisionedExpirerTest.java
index 359f75c27ab..ac1e452d7a5 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ProvisionedExpirerTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ProvisionedExpirerTest.java
@@ -9,6 +9,7 @@ import com.yahoo.config.provision.NodeType;
import com.yahoo.config.provision.RegionName;
import com.yahoo.config.provision.SystemName;
import com.yahoo.config.provision.Zone;
+import com.yahoo.vespa.hosted.provision.LockedNodeList;
import com.yahoo.vespa.hosted.provision.Node;
import com.yahoo.vespa.hosted.provision.node.Agent;
import com.yahoo.vespa.hosted.provision.provisioning.ProvisioningTester;
@@ -45,7 +46,7 @@ public class ProvisionedExpirerTest {
var nodes = IntStream.range(0, 15)
.mapToObj(i -> Node.create("id-" + i, "host-" + i, new Flavor(NodeResources.unspecified()), Node.State.provisioned, NodeType.host).build())
.toList();
- tester.nodeRepository().database().addNodesInState(nodes, Node.State.provisioned, Agent.system);
+ tester.nodeRepository().database().addNodesInState(new LockedNodeList(nodes, () -> { }), Node.State.provisioned, Agent.system);
}
}
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/SpareCapacityMaintainerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/SpareCapacityMaintainerTest.java
index b54975cbf41..a5ac2be72ee 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/SpareCapacityMaintainerTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/SpareCapacityMaintainerTest.java
@@ -2,6 +2,7 @@
package com.yahoo.vespa.hosted.provision.maintenance;
import com.yahoo.config.provision.ApplicationId;
+import com.yahoo.config.provision.ApplicationTransaction;
import com.yahoo.config.provision.ClusterMembership;
import com.yahoo.config.provision.ClusterSpec;
import com.yahoo.config.provision.DockerImage;
@@ -10,6 +11,7 @@ import com.yahoo.config.provision.Flavor;
import com.yahoo.config.provision.NodeFlavors;
import com.yahoo.config.provision.NodeResources;
import com.yahoo.config.provision.NodeType;
+import com.yahoo.config.provision.ProvisionLock;
import com.yahoo.config.provision.RegionName;
import com.yahoo.config.provision.Zone;
import com.yahoo.test.ManualClock;
@@ -313,7 +315,7 @@ public class SpareCapacityMaintainerTest {
}
private void allocate(ApplicationId application, ClusterSpec clusterSpec, List<Node> nodes) {
- nodes = nodeRepository.nodes().addNodes(nodes, Agent.system);
+ nodes = new ArrayList<>(nodeRepository.nodes().addNodes(nodes, Agent.system));
for (int i = 0; i < nodes.size(); i++) {
Node node = nodes.get(i);
ClusterMembership membership = ClusterMembership.from(clusterSpec, i);
@@ -322,7 +324,7 @@ public class SpareCapacityMaintainerTest {
}
nodes = nodeRepository.nodes().reserve(nodes);
var transaction = new NestedTransaction();
- nodes = nodeRepository.nodes().activate(nodes, transaction);
+ nodes = nodeRepository.nodes().activate(nodes, new ApplicationTransaction(new ProvisionLock(application, () -> { }), transaction));
transaction.commit();
}
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicAllocationTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicAllocationTest.java
index 47d34a76dd6..478b201d71b 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicAllocationTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicAllocationTest.java
@@ -2,6 +2,7 @@
package com.yahoo.vespa.hosted.provision.provisioning;
import com.yahoo.config.provision.ApplicationId;
+import com.yahoo.config.provision.ApplicationTransaction;
import com.yahoo.config.provision.Capacity;
import com.yahoo.config.provision.ClusterMembership;
import com.yahoo.config.provision.ClusterResources;
@@ -12,6 +13,7 @@ import com.yahoo.config.provision.HostSpec;
import com.yahoo.config.provision.NodeResources;
import com.yahoo.config.provision.NodeType;
import com.yahoo.config.provision.NodeAllocationException;
+import com.yahoo.config.provision.ProvisionLock;
import com.yahoo.config.provision.RegionName;
import com.yahoo.config.provision.SystemName;
import com.yahoo.config.provision.Zone;
@@ -540,9 +542,9 @@ public class DynamicAllocationTest {
clusterSpec.with(Optional.of(ClusterSpec.Group.from(0))), index); // Need to add group here so that group is serialized in node allocation
Node node1aAllocation = node1a.allocate(id, clusterMembership1, node1a.resources(), Instant.now());
- tester.nodeRepository().nodes().addNodes(Collections.singletonList(node1aAllocation), Agent.system);
+ tester.nodeRepository().nodes().addNodes(List.of(node1aAllocation), Agent.system);
NestedTransaction transaction = new NestedTransaction().add(new CuratorTransaction(tester.getCurator()));
- tester.nodeRepository().nodes().activate(Collections.singletonList(node1aAllocation), transaction);
+ tester.nodeRepository().nodes().activate(List.of(node1aAllocation), new ApplicationTransaction(new ProvisionLock(id, () -> { }), transaction));
transaction.commit();
}
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java
index 2acbeb00f5f..dd8f97d82de 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java
@@ -214,7 +214,7 @@ public class ProvisioningTester {
NestedTransaction t = new NestedTransaction();
if (parent.ipConfig().primary().isEmpty())
parent = parent.with(IP.Config.of(Set.of("::" + 0 + ":0"), Set.of("::" + 0 + ":2")));
- nodeRepository.nodes().activate(List.of(parent), t);
+ nodeRepository.nodes().activate(List.of(parent), new ApplicationTransaction(new ProvisionLock(application, () -> { }), t));
t.commit();
}
}
diff --git a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h
index 307d1a0d112..25ca7729a32 100644
--- a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h
+++ b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h
@@ -79,11 +79,12 @@ public:
static vespalib::datastore::ArrayStoreConfig optimizedConfigForHugePage(size_t max_type_id,
- size_t hugePageSize,
- size_t smallPageSize,
- size_t min_num_entries_for_new_buffer,
- float allocGrowFactor,
- bool enable_free_lists);
+ size_t hugePageSize,
+ size_t smallPageSize,
+ size_t max_buffer_size,
+ size_t min_num_entries_for_new_buffer,
+ float allocGrowFactor,
+ bool enable_free_lists);
};
}
diff --git a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp
index 99808b11e92..3c9a52f2e5c 100644
--- a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp
@@ -66,14 +66,15 @@ MultiValueMapping<ElemT, RefT>::getAddressSpaceUsage() const {
template <typename ElemT, typename RefT>
vespalib::datastore::ArrayStoreConfig
MultiValueMapping<ElemT, RefT>::optimizedConfigForHugePage(size_t max_type_id,
- size_t hugePageSize,
- size_t smallPageSize,
- size_t min_num_entries_for_new_buffer,
- float allocGrowFactor,
- bool enable_free_lists)
+ size_t hugePageSize,
+ size_t smallPageSize,
+ size_t max_buffer_size,
+ size_t min_num_entries_for_new_buffer,
+ float allocGrowFactor,
+ bool enable_free_lists)
{
ArrayStoreTypeMapper mapper(max_type_id, array_store_grow_factor);
- auto result = ArrayStore::optimizedConfigForHugePage(max_type_id, mapper, hugePageSize, smallPageSize, min_num_entries_for_new_buffer, allocGrowFactor);
+ auto result = ArrayStore::optimizedConfigForHugePage(max_type_id, mapper, hugePageSize, smallPageSize, max_buffer_size, min_num_entries_for_new_buffer, allocGrowFactor);
result.enable_free_lists(enable_free_lists);
return result;
}
diff --git a/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp
index d8ada97fa2c..3cf75b450af 100644
--- a/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp
@@ -28,6 +28,7 @@ MultiValueAttribute(const vespalib::string &baseFileName,
_mvMapping(MultiValueMapping::optimizedConfigForHugePage(MultiValueMapping::array_store_max_type_id,
vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
vespalib::alloc::MemoryAllocator::PAGE_SIZE,
+ vespalib::datastore::ArrayStoreConfig::default_max_buffer_size,
8 * 1024,
cfg.getGrowStrategy().getMultiValueAllocGrowFactor(),
multivalueattribute::enable_free_lists),
diff --git a/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp b/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp
index cd9e0508344..00c195b9eb7 100644
--- a/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp
@@ -20,6 +20,7 @@ RawBufferStore::RawBufferStore(std::shared_ptr<vespalib::alloc::MemoryAllocator>
TypeMapper(max_small_buffer_type_id, grow_factor),
MemoryAllocator::HUGEPAGE_SIZE,
MemoryAllocator::PAGE_SIZE,
+ vespalib::datastore::ArrayStoreConfig::default_max_buffer_size,
8_Ki, ALLOC_GROW_FACTOR),
std::move(allocator), TypeMapper(max_small_buffer_type_id, grow_factor))
{
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index 748a747d515..22a33270a27 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -31,6 +31,7 @@ namespace search::tensor {
using search::AddressSpaceComponents;
using search::StateExplorerUtils;
using search::queryeval::GlobalFilter;
+using vespalib::datastore::ArrayStoreConfig;
using vespalib::datastore::CompactionStrategy;
using vespalib::datastore::EntryRef;
using vespalib::GenericHeader;
@@ -145,25 +146,27 @@ PreparedAddDoc::PreparedAddDoc(PreparedAddDoc&& other) noexcept = default;
}
template <HnswIndexType type>
-vespalib::datastore::ArrayStoreConfig
+ArrayStoreConfig
HnswIndex<type>::make_default_level_array_store_config()
{
return LevelArrayStore::optimizedConfigForHugePage(max_level_array_size,
- vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
- vespalib::alloc::MemoryAllocator::PAGE_SIZE,
+ vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
+ vespalib::alloc::MemoryAllocator::PAGE_SIZE,
+ ArrayStoreConfig::default_max_buffer_size,
min_num_arrays_for_new_buffer,
alloc_grow_factor).enable_free_lists(true);
}
template <HnswIndexType type>
-vespalib::datastore::ArrayStoreConfig
+ArrayStoreConfig
HnswIndex<type>::make_default_link_array_store_config()
{
return LinkArrayStore::optimizedConfigForHugePage(max_link_array_size,
- vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
- vespalib::alloc::MemoryAllocator::PAGE_SIZE,
- min_num_arrays_for_new_buffer,
- alloc_grow_factor).enable_free_lists(true);
+ vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
+ vespalib::alloc::MemoryAllocator::PAGE_SIZE,
+ ArrayStoreConfig::default_max_buffer_size,
+ min_num_arrays_for_new_buffer,
+ alloc_grow_factor).enable_free_lists(true);
}
template <HnswIndexType type>
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp
index a78d9cefc64..cf30d62a0b8 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp
@@ -49,6 +49,7 @@ HnswNodeidMapping::HnswNodeidMapping()
_nodeids(NodeidStore::optimizedConfigForHugePage(max_type_id,
vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
vespalib::alloc::MemoryAllocator::PAGE_SIZE,
+ vespalib::datastore::ArrayStoreConfig::default_max_buffer_size,
min_num_arrays_for_new_buffer,
alloc_grow_factor).enable_free_lists(true), {}),
_hold_list(),
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp
index ff39c33fc5d..29f20e27d09 100644
--- a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp
@@ -38,6 +38,7 @@ TensorBufferStore::TensorBufferStore(const ValueType& tensor_type, std::shared_p
TensorBufferTypeMapper(max_small_subspaces_type_id, mapper_grow_factor, &_ops),
MemoryAllocator::HUGEPAGE_SIZE,
MemoryAllocator::PAGE_SIZE,
+ vespalib::datastore::ArrayStoreConfig::default_max_buffer_size,
8_Ki, ALLOC_GROW_FACTOR),
std::move(allocator), TensorBufferTypeMapper(max_small_subspaces_type_id, mapper_grow_factor, &_ops))
{
diff --git a/storage/src/vespa/storage/bucketdb/btree_bucket_database.cpp b/storage/src/vespa/storage/bucketdb/btree_bucket_database.cpp
index baec5494b36..efa7e18aa33 100644
--- a/storage/src/vespa/storage/bucketdb/btree_bucket_database.cpp
+++ b/storage/src/vespa/storage/bucketdb/btree_bucket_database.cpp
@@ -40,6 +40,7 @@ vespalib::datastore::ArrayStoreConfig make_default_array_store_config() {
return ReplicaStore::optimizedConfigForHugePage(1023,
vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
vespalib::alloc::MemoryAllocator::PAGE_SIZE,
+ vespalib::datastore::ArrayStoreConfig::default_max_buffer_size,
8_Ki, 0.2).enable_free_lists(true);
}
diff --git a/vespalib/src/tests/datastore/array_store/array_store_test.cpp b/vespalib/src/tests/datastore/array_store/array_store_test.cpp
index 2674acf1ce9..797dc97c963 100644
--- a/vespalib/src/tests/datastore/array_store/array_store_test.cpp
+++ b/vespalib/src/tests/datastore/array_store/array_store_test.cpp
@@ -578,6 +578,7 @@ struct ByteStoreTest : public ArrayStoreTest<testing::Test, uint8_t, EntryRefT<1
optimizedConfigForHugePage(1023,
vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
vespalib::alloc::MemoryAllocator::PAGE_SIZE,
+ ArrayStoreConfig::default_max_buffer_size,
8_Ki, ALLOC_GROW_FACTOR)) {}
};
diff --git a/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp b/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp
index 71c1341ae74..3bcc130052d 100644
--- a/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp
+++ b/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp
@@ -22,15 +22,20 @@ struct Fixture
Fixture(uint32_t max_type_id,
size_t hugePageSize,
size_t smallPageSize,
+ size_t max_buffer_size,
size_t min_num_entries_for_new_buffer)
: cfg(ArrayStoreConfig::optimizeForHugePage(max_type_id,
[](size_t type_id) noexcept { return type_id * sizeof(int); },
hugePageSize, smallPageSize,
EntryRefType::offsetSize(),
+ max_buffer_size,
min_num_entries_for_new_buffer,
ALLOC_GROW_FACTOR)) { }
void assertSpec(uint32_t type_id, uint32_t num_entries_for_new_buffer) {
- assertSpec(type_id, AllocSpec(0, EntryRefType::offsetSize(),
+ assertSpec(type_id, EntryRefType::offsetSize(), num_entries_for_new_buffer);
+ }
+ void assertSpec(uint32_t type_id, uint32_t max_entries, uint32_t num_entries_for_new_buffer) {
+ assertSpec(type_id, AllocSpec(0, max_entries,
num_entries_for_new_buffer, ALLOC_GROW_FACTOR));
}
void assertSpec(uint32_t type_id, const AllocSpec &expSpec) {
@@ -50,9 +55,6 @@ makeSpec(size_t min_entries_in_buffer,
return AllocSpec(min_entries_in_buffer, max_entries_in_buffer, num_entries_for_new_buffer, ALLOC_GROW_FACTOR);
}
-constexpr size_t KB = 1024;
-constexpr size_t MB = KB * KB;
-
TEST_F("require that default allocation spec is given for all array sizes", Fixture(3, makeSpec(4, 32, 8)))
{
EXPECT_EQUAL(3u, f.cfg.max_type_id());
@@ -62,26 +64,54 @@ TEST_F("require that default allocation spec is given for all array sizes", Fixt
TEST_DO(f.assertSpec(3, makeSpec(4, 32, 8)));
}
-TEST_F("require that we can generate config optimized for a given huge page", Fixture(1024,
- 2 * MB,
- 4 * KB,
- 8 * KB))
+struct BigBuffersFixture : public Fixture {
+ BigBuffersFixture() : Fixture(1023, 2_Mi, 4_Ki, 1024_Gi, 8_Ki) { }
+};
+
+TEST_F("require that we can generate config optimized for a given huge page without capped buffer sizes", BigBuffersFixture())
+{
+ EXPECT_EQUAL(1023u, f.cfg.max_type_id());
+ TEST_DO(f.assertSpec(0, 8_Ki)); // large arrays
+ TEST_DO(f.assertSpec(1, 256_Ki));
+ TEST_DO(f.assertSpec(2, 256_Ki));
+ TEST_DO(f.assertSpec(3, 168_Ki));
+ TEST_DO(f.assertSpec(4, 128_Ki));
+ TEST_DO(f.assertSpec(5, 100_Ki));
+ TEST_DO(f.assertSpec(6, 84_Ki));
+
+ TEST_DO(f.assertSpec(32, 16_Ki));
+ TEST_DO(f.assertSpec(33, 12_Ki));
+ TEST_DO(f.assertSpec(42, 12_Ki));
+ TEST_DO(f.assertSpec(43, 8_Ki));
+ TEST_DO(f.assertSpec(1022, 8_Ki));
+ TEST_DO(f.assertSpec(1023, 8_Ki));
+}
+
+struct CappedBuffersFixture : public Fixture {
+ CappedBuffersFixture() : Fixture(1023, 2_Mi, 4_Ki, 256_Mi, 8_Ki) { }
+ size_t max_entries(size_t array_size) {
+ auto entry_size = array_size * sizeof(int);
+ return (256_Mi + entry_size - 1) / entry_size;
+ }
+};
+
+TEST_F("require that we can generate config optimized for a given huge page with capped buffer sizes", CappedBuffersFixture())
{
- EXPECT_EQUAL(1_Ki, f.cfg.max_type_id());
- TEST_DO(f.assertSpec(0, 8 * KB)); // large arrays
- TEST_DO(f.assertSpec(1, 256 * KB));
- TEST_DO(f.assertSpec(2, 256 * KB));
- TEST_DO(f.assertSpec(3, 168 * KB));
- TEST_DO(f.assertSpec(4, 128 * KB));
- TEST_DO(f.assertSpec(5, 100 * KB));
- TEST_DO(f.assertSpec(6, 84 * KB));
+ EXPECT_EQUAL(1023u, f.cfg.max_type_id());
+ TEST_DO(f.assertSpec(0, f.max_entries(1023), 8_Ki)); // large arrays
+ TEST_DO(f.assertSpec(1, 256_Ki));
+ TEST_DO(f.assertSpec(2, 256_Ki));
+ TEST_DO(f.assertSpec(3, 168_Ki));
+ TEST_DO(f.assertSpec(4, 128_Ki));
+ TEST_DO(f.assertSpec(5, 100_Ki));
+ TEST_DO(f.assertSpec(6, 84_Ki));
- TEST_DO(f.assertSpec(32, 16 * KB));
- TEST_DO(f.assertSpec(33, 12 * KB));
- TEST_DO(f.assertSpec(42, 12 * KB));
- TEST_DO(f.assertSpec(43, 8 * KB));
- TEST_DO(f.assertSpec(1022, 8 * KB));
- TEST_DO(f.assertSpec(1023, 8 * KB));
+ TEST_DO(f.assertSpec(32, 16_Ki));
+ TEST_DO(f.assertSpec(33, 12_Ki));
+ TEST_DO(f.assertSpec(42, 12_Ki));
+ TEST_DO(f.assertSpec(43, 8_Ki));
+ TEST_DO(f.assertSpec(1022, f.max_entries(1022), 8_Ki));
+ TEST_DO(f.assertSpec(1023, f.max_entries(1023), 8_Ki));
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/vespalib/src/vespa/vespalib/datastore/array_store.h b/vespalib/src/vespa/vespalib/datastore/array_store.h
index 0490687aeb8..7ee63be3848 100644
--- a/vespalib/src/vespa/vespalib/datastore/array_store.h
+++ b/vespalib/src/vespa/vespalib/datastore/array_store.h
@@ -196,6 +196,7 @@ public:
static ArrayStoreConfig optimizedConfigForHugePage(uint32_t max_type_id,
size_t hugePageSize,
size_t smallPageSize,
+ size_t max_buffer_size,
size_t min_num_entries_for_new_buffer,
float allocGrowFactor);
@@ -203,6 +204,7 @@ public:
const TypeMapper& mapper,
size_t hugePageSize,
size_t smallPageSize,
+ size_t max_buffer_size,
size_t min_num_entries_for_new_buffer,
float allocGrowFactor);
};
diff --git a/vespalib/src/vespa/vespalib/datastore/array_store.hpp b/vespalib/src/vespa/vespalib/datastore/array_store.hpp
index 211176b8ad0..bfd4ff0430a 100644
--- a/vespalib/src/vespa/vespalib/datastore/array_store.hpp
+++ b/vespalib/src/vespa/vespalib/datastore/array_store.hpp
@@ -252,6 +252,7 @@ ArrayStoreConfig
ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t max_type_id,
size_t hugePageSize,
size_t smallPageSize,
+ size_t max_buffer_size,
size_t min_num_entries_for_new_buffer,
float allocGrowFactor)
{
@@ -260,6 +261,7 @@ ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t max_ty
mapper,
hugePageSize,
smallPageSize,
+ max_buffer_size,
min_num_entries_for_new_buffer,
allocGrowFactor);
}
@@ -267,17 +269,19 @@ ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t max_ty
template <typename ElemT, typename RefT, typename TypeMapperT>
ArrayStoreConfig
ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t max_type_id,
- const TypeMapper& mapper,
- size_t hugePageSize,
- size_t smallPageSize,
- size_t min_num_entries_for_new_buffer,
- float allocGrowFactor)
+ const TypeMapper& mapper,
+ size_t hugePageSize,
+ size_t smallPageSize,
+ size_t max_buffer_size,
+ size_t min_num_entries_for_new_buffer,
+ float allocGrowFactor)
{
return ArrayStoreConfig::optimizeForHugePage(mapper.get_max_type_id(max_type_id),
[&](uint32_t type_id) noexcept { return mapper.get_entry_size(type_id); },
hugePageSize,
smallPageSize,
RefT::offsetSize(),
+ max_buffer_size,
min_num_entries_for_new_buffer,
allocGrowFactor);
}
diff --git a/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp b/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp
index c7f0b69a85e..37f6fab96dc 100644
--- a/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp
+++ b/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "array_store_config.h"
+#include <algorithm>
#include <cassert>
namespace vespalib::datastore {
@@ -42,6 +43,13 @@ alignToSmallPageSize(size_t value, size_t minLimit, size_t smallPageSize)
return ((value - minLimit) / smallPageSize) * smallPageSize + minLimit;
}
+size_t
+cap_max_entries(size_t max_entries, size_t max_buffer_size, size_t entry_size)
+{
+ size_t dynamic_max_entries = (max_buffer_size + (entry_size - 1)) / entry_size;
+ return std::min(max_entries, dynamic_max_entries);
+}
+
}
ArrayStoreConfig
@@ -50,17 +58,21 @@ ArrayStoreConfig::optimizeForHugePage(uint32_t max_type_id,
size_t hugePageSize,
size_t smallPageSize,
size_t maxEntryRefOffset,
+ size_t max_buffer_size,
size_t min_num_entries_for_new_buffer,
float allocGrowFactor)
{
AllocSpecVector allocSpecs;
- allocSpecs.emplace_back(0, maxEntryRefOffset, min_num_entries_for_new_buffer, allocGrowFactor); // large array spec;
+ auto entry_size = type_id_to_entry_size(max_type_id);
+ auto capped_max_entries = cap_max_entries(maxEntryRefOffset, max_buffer_size, entry_size);
+ allocSpecs.emplace_back(0, capped_max_entries, min_num_entries_for_new_buffer, allocGrowFactor); // large array spec;
for (uint32_t type_id = 1; type_id <= max_type_id; ++type_id) {
- size_t entry_size = type_id_to_entry_size(type_id);
+ entry_size = type_id_to_entry_size(type_id);
+ capped_max_entries = cap_max_entries(maxEntryRefOffset, max_buffer_size, entry_size);
size_t num_entries_for_new_buffer = hugePageSize / entry_size;
- num_entries_for_new_buffer = capToLimits(num_entries_for_new_buffer, min_num_entries_for_new_buffer, maxEntryRefOffset);
+ num_entries_for_new_buffer = capToLimits(num_entries_for_new_buffer, min_num_entries_for_new_buffer, capped_max_entries);
num_entries_for_new_buffer = alignToSmallPageSize(num_entries_for_new_buffer, min_num_entries_for_new_buffer, smallPageSize);
- allocSpecs.emplace_back(0, maxEntryRefOffset, num_entries_for_new_buffer, allocGrowFactor);
+ allocSpecs.emplace_back(0, capped_max_entries, num_entries_for_new_buffer, allocGrowFactor);
}
return ArrayStoreConfig(allocSpecs);
}
diff --git a/vespalib/src/vespa/vespalib/datastore/array_store_config.h b/vespalib/src/vespa/vespalib/datastore/array_store_config.h
index 3b62609d0f1..3967996c64d 100644
--- a/vespalib/src/vespa/vespalib/datastore/array_store_config.h
+++ b/vespalib/src/vespa/vespalib/datastore/array_store_config.h
@@ -2,6 +2,7 @@
#pragma once
+#include <vespa/vespalib/util/size_literals.h>
#include <cstddef>
#include <cstdint>
#include <functional>
@@ -39,6 +40,8 @@ public:
using AllocSpecVector = std::vector<AllocSpec>;
+ static constexpr size_t default_max_buffer_size = 256_Mi;
+
private:
AllocSpecVector _allocSpecs;
bool _enable_free_lists;
@@ -77,6 +80,7 @@ public:
size_t hugePageSize,
size_t smallPageSize,
size_t maxEntryRefOffset,
+ size_t max_buffer_size,
size_t min_num_entries_for_new_buffer,
float allocGrowFactor);
};