aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--airlift-zstd/src/test/java/ai/vespa/airlift/zstd/TestCompressor.java1
-rw-r--r--client/go/internal/cli/auth/auth0/auth0.go38
-rw-r--r--client/go/internal/cli/auth/zts/zts.go28
-rw-r--r--client/go/internal/cli/auth/zts/zts_test.go7
-rw-r--r--client/go/internal/cli/cmd/cert.go9
-rw-r--r--client/go/internal/cli/cmd/config.go91
-rw-r--r--client/go/internal/cli/cmd/config_test.go119
-rw-r--r--client/go/internal/cli/cmd/curl.go12
-rw-r--r--client/go/internal/cli/cmd/feed.go92
-rw-r--r--client/go/internal/cli/cmd/root.go98
-rw-r--r--client/go/internal/cli/cmd/test.go2
-rw-r--r--client/go/internal/cli/cmd/testutil_test.go21
-rw-r--r--client/go/internal/util/http.go29
-rw-r--r--client/go/internal/vespa/crypto.go2
-rw-r--r--client/go/internal/vespa/deploy.go24
-rw-r--r--client/go/internal/vespa/deploy_test.go4
-rw-r--r--client/go/internal/vespa/document/dispatcher.go141
-rw-r--r--client/go/internal/vespa/document/dispatcher_test.go6
-rw-r--r--client/go/internal/vespa/document/document.go24
-rw-r--r--client/go/internal/vespa/document/http.go63
-rw-r--r--client/go/internal/vespa/document/http_test.go52
-rw-r--r--client/go/internal/vespa/document/queue.go43
-rw-r--r--client/go/internal/vespa/document/queue_test.go29
-rw-r--r--client/go/internal/vespa/target.go76
-rw-r--r--client/go/internal/vespa/target_cloud.go109
-rw-r--r--client/go/internal/vespa/target_custom.go25
-rw-r--r--client/go/internal/vespa/target_test.go46
-rw-r--r--client/pom.xml6
-rw-r--r--cloud-tenant-base-dependencies-enforcer/pom.xml15
-rw-r--r--clustercontroller-core/pom.xml5
-rw-r--r--clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/ContentCluster.java8
-rw-r--r--clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java151
-rw-r--r--clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeCheckerTest.java407
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java1
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/content/ClusterControllerConfig.java21
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java3
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/search/NodeResourcesTuning.java10
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/content/FleetControllerClusterTest.java3
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/search/NodeResourcesTuningTest.java12
-rw-r--r--configdefinitions/src/vespa/fleetcontroller.def15
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java36
-rw-r--r--container-search/abi-spec.json3
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/hitfield/RawBase64.java21
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/query/MultiRangeItem.java2
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/query/MultiTermItem.java4
-rw-r--r--container-search/src/main/java/com/yahoo/search/grouping/result/BucketGroupId.java2
-rw-r--r--container-search/src/main/java/com/yahoo/search/grouping/result/HitRenderer.java24
-rw-r--r--container-search/src/main/java/com/yahoo/search/grouping/result/RawBucketId.java6
-rw-r--r--container-search/src/main/java/com/yahoo/search/grouping/result/RawId.java6
-rw-r--r--container-search/src/main/java/com/yahoo/search/grouping/vespa/ResultBuilder.java25
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/compiled/Binding.java2
-rw-r--r--container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java25
-rw-r--r--container-search/src/test/java/com/yahoo/search/grouping/result/GroupIdTestCase.java11
-rw-r--r--container-search/src/test/java/com/yahoo/search/grouping/result/HitRendererTestCase.java4
-rw-r--r--container-search/src/test/java/com/yahoo/search/grouping/vespa/ResultBuilderTestCase.java92
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/ApplicationVersion.java34
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Change.java76
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/InstanceList.java2
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java13
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java20
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java2
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java56
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/RevisionHistory.java2
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/Versions.java5
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java20
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java22
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelper.java4
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/deployment/DeploymentApiHandler.java4
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java62
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java34
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/UpgraderTest.java20
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java12
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java42
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment-overview.json25
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/deployment/responses/root.json10
-rw-r--r--document/src/main/java/com/yahoo/document/json/JsonSerializationHelper.java7
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/DocumentUpdateFlags.java2
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/XmlSerializationHelper.java6
-rw-r--r--document/src/test/java/com/yahoo/document/DocumentTestCase.java2
-rw-r--r--document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java2
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java2
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonWriterTestCase.java2
-rw-r--r--fat-model-dependencies/pom.xml4
-rw-r--r--flags/src/main/java/com/yahoo/vespa/flags/Flags.java6
-rw-r--r--jdisc_core/src/test/java/com/yahoo/jdisc/core/ExportPackagesIT.java12
-rw-r--r--model-integration/pom.xml17
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java18
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/Completion.java41
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/Generator.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/LanguageModel.java18
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/Prompt.java23
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/StringPrompt.java43
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java84
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/test/MockLanguageModel.java44
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java14
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java4
-rw-r--r--model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def4
-rw-r--r--model-integration/src/test/java/ai/vespa/llm/CompletionTest.java37
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/reports/DropDocumentsReport.java55
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java22
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java24
-rw-r--r--node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImplTest.java55
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityChecker.java2
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java4
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodePatcher.java61
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/NodesV2ApiTest.java22
-rw-r--r--parent/pom.xml17
-rw-r--r--predicate-search/src/main/java/com/yahoo/search/predicate/benchmarks/ResultMetrics.java4
-rw-r--r--predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java5
-rw-r--r--searchcore/src/tests/proton/attribute/attribute_initializer/attribute_initializer_test.cpp6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/expression/Int16ResultNode.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/expression/Int32ResultNode.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/expression/Int8ResultNode.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/expression/RawResultNode.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/ranking/features/fieldmatch/FieldMatchMetrics.java88
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java14
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/gbdt/GbdtConverterTestCase.java58
-rw-r--r--searchlib/src/tests/attribute/attribute_test.cpp28
-rw-r--r--searchlib/src/tests/attribute/enumeratedsave/enumeratedsave_test.cpp21
-rw-r--r--searchlib/src/tests/attribute/enumstore/enumstore_test.cpp77
-rw-r--r--searchlib/src/tests/query/streaming_query_test.cpp37
-rw-r--r--searchlib/src/vespa/searchcommon/common/undefinedvalues.h4
-rw-r--r--searchlib/src/vespa/searchlib/attribute/attributevector.cpp15
-rw-r--r--searchlib/src/vespa/searchlib/attribute/enum_store_loaders.cpp8
-rw-r--r--searchlib/src/vespa/searchlib/attribute/enum_store_loaders.h1
-rw-r--r--searchlib/src/vespa/searchlib/attribute/enumattribute.h3
-rw-r--r--searchlib/src/vespa/searchlib/attribute/enumattribute.hpp14
-rw-r--r--searchlib/src/vespa/searchlib/attribute/enumstore.h10
-rw-r--r--searchlib/src/vespa/searchlib/attribute/enumstore.hpp46
-rw-r--r--searchlib/src/vespa/searchlib/attribute/i_enum_store.h2
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multinumericenumattribute.hpp4
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multistringattribute.hpp1
-rw-r--r--searchlib/src/vespa/searchlib/attribute/postinglistattribute.cpp1
-rw-r--r--searchlib/src/vespa/searchlib/attribute/singleenumattribute.h4
-rw-r--r--searchlib/src/vespa/searchlib/attribute/singleenumattribute.hpp22
-rw-r--r--searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp3
-rw-r--r--searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.h2
-rw-r--r--searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.hpp10
-rw-r--r--searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp8
-rw-r--r--searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp1
-rw-r--r--searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp7
-rw-r--r--searchlib/src/vespa/searchlib/attribute/stringbase.cpp4
-rw-r--r--searchlib/src/vespa/searchlib/attribute/stringbase.h2
-rw-r--r--searchlib/src/vespa/searchlib/query/query_term_simple.h4
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt1
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp36
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h35
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/querynode.cpp20
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/querynode.h3
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp6
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/queryterm.h5
-rw-r--r--security-utils/src/main/java/com/yahoo/security/SideChannelSafe.java2
-rw-r--r--security-utils/src/test/java/com/yahoo/security/SharedKeyTest.java6
-rw-r--r--storage/src/tests/persistence/persistencetestutils.h12
-rw-r--r--storage/src/tests/persistence/testandsettest.cpp81
-rw-r--r--storage/src/tests/storageapi/mbusprot/storageprotocoltest.cpp2
-rw-r--r--storage/src/vespa/storage/persistence/asynchandler.cpp6
-rw-r--r--storage/src/vespa/storage/persistence/persistencehandler.cpp2
-rw-r--r--storage/src/vespa/storage/persistence/simplemessagehandler.cpp33
-rw-r--r--storage/src/vespa/storage/persistence/simplemessagehandler.h13
-rw-r--r--storage/src/vespa/storage/persistence/testandsethelper.cpp84
-rw-r--r--storage/src/vespa/storage/persistence/testandsethelper.h52
-rw-r--r--storage/src/vespa/storageapi/message/persistence.cpp6
-rw-r--r--storage/src/vespa/storageapi/message/persistence.h15
-rw-r--r--vdslib/src/main/java/com/yahoo/vdslib/distribution/Distribution.java4
-rw-r--r--vespa-dependencies-enforcer/allowed-maven-dependencies.txt10
-rw-r--r--vespa-feed-client-api/pom.xml6
-rw-r--r--vespa-feed-client-cli/pom.xml6
-rw-r--r--vespa-feed-client/pom.xml6
-rw-r--r--vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java69
-rw-r--r--vespaclient-container-plugin/src/test/java/com/yahoo/document/restapi/resource/DocumentV1ApiTest.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/slime/BinaryEncoder.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/io/FatalErrorHandlerTestCase.java58
175 files changed, 3038 insertions, 1322 deletions
diff --git a/airlift-zstd/src/test/java/ai/vespa/airlift/zstd/TestCompressor.java b/airlift-zstd/src/test/java/ai/vespa/airlift/zstd/TestCompressor.java
index d6f13b98c71..4aa00f91ffc 100644
--- a/airlift-zstd/src/test/java/ai/vespa/airlift/zstd/TestCompressor.java
+++ b/airlift-zstd/src/test/java/ai/vespa/airlift/zstd/TestCompressor.java
@@ -20,6 +20,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET;
+@SuppressWarnings("proprietary")
public class TestCompressor
{
@Test
diff --git a/client/go/internal/cli/auth/auth0/auth0.go b/client/go/internal/cli/auth/auth0/auth0.go
index 5f7612d4d2e..6fcd3f7680e 100644
--- a/client/go/internal/cli/auth/auth0/auth0.go
+++ b/client/go/internal/cli/auth/auth0/auth0.go
@@ -110,28 +110,40 @@ func (a *Client) getDeviceFlowConfig() (flowConfig, error) {
}
r, err := a.httpClient.Do(req, time.Second*30)
if err != nil {
- return flowConfig{}, fmt.Errorf("failed to get device flow config: %w", err)
+ return flowConfig{}, fmt.Errorf("auth0: failed to get device flow config: %w", err)
}
defer r.Body.Close()
if r.StatusCode/100 != 2 {
- return flowConfig{}, fmt.Errorf("failed to get device flow config: got response code %d from %s", r.StatusCode, url)
+ return flowConfig{}, fmt.Errorf("auth0: failed to get device flow config: got response code %d from %s", r.StatusCode, url)
}
var cfg flowConfig
if err := json.NewDecoder(r.Body).Decode(&cfg); err != nil {
- return flowConfig{}, fmt.Errorf("failed to decode response: %w", err)
+ return flowConfig{}, fmt.Errorf("auth0: failed to decode response: %w", err)
}
return cfg, nil
}
+func (a *Client) Authenticate(request *http.Request) error {
+ accessToken, err := a.AccessToken()
+ if err != nil {
+ return err
+ }
+ if request.Header == nil {
+ request.Header = make(http.Header)
+ }
+ request.Header.Set("Authorization", "Bearer "+accessToken)
+ return nil
+}
+
// AccessToken returns an access token for the configured system, refreshing it if necessary.
func (a *Client) AccessToken() (string, error) {
creds, ok := a.provider.Systems[a.options.SystemName]
if !ok {
- return "", fmt.Errorf("system %s is not configured", a.options.SystemName)
+ return "", fmt.Errorf("auth0: system %s is not configured: %s", a.options.SystemName, reauthMessage)
} else if creds.AccessToken == "" {
- return "", fmt.Errorf("access token missing: %s", reauthMessage)
+ return "", fmt.Errorf("auth0: access token missing: %s", reauthMessage)
} else if scopesChanged(creds) {
- return "", fmt.Errorf("authentication scopes changed: %s", reauthMessage)
+ return "", fmt.Errorf("auth0: authentication scopes changed: %s", reauthMessage)
} else if isExpired(creds.ExpiresAt, accessTokenExpiry) {
// check if the stored access token is expired:
// use the refresh token to get a new access token:
@@ -142,7 +154,7 @@ func (a *Client) AccessToken() (string, error) {
}
resp, err := tr.Refresh(cancelOnInterrupt(), a.options.SystemName)
if err != nil {
- return "", fmt.Errorf("failed to renew access token: %w: %s", err, reauthMessage)
+ return "", fmt.Errorf("auth0: failed to renew access token: %w: %s", err, reauthMessage)
} else {
// persist the updated system with renewed access token
creds.AccessToken = resp.AccessToken
@@ -173,12 +185,6 @@ func scopesChanged(s Credentials) bool {
return false
}
-// HasCredentials returns true if this client has retrived credentials for the configured system.
-func (a *Client) HasCredentials() bool {
- _, ok := a.provider.Systems[a.options.SystemName]
- return ok
-}
-
// WriteCredentials writes given credentials to the configuration file.
func (a *Client) WriteCredentials(credentials Credentials) error {
if a.provider.Systems == nil {
@@ -186,7 +192,7 @@ func (a *Client) WriteCredentials(credentials Credentials) error {
}
a.provider.Systems[a.options.SystemName] = credentials
if err := writeConfig(a.provider, a.options.ConfigPath); err != nil {
- return fmt.Errorf("failed to write config: %w", err)
+ return fmt.Errorf("auth0: failed to write config: %w", err)
}
return nil
}
@@ -195,11 +201,11 @@ func (a *Client) WriteCredentials(credentials Credentials) error {
func (a *Client) RemoveCredentials() error {
tr := &auth.TokenRetriever{Secrets: &auth.Keyring{}}
if err := tr.Delete(a.options.SystemName); err != nil {
- return fmt.Errorf("failed to remove system %s from secret storage: %w", a.options.SystemName, err)
+ return fmt.Errorf("auth0: failed to remove system %s from secret storage: %w", a.options.SystemName, err)
}
delete(a.provider.Systems, a.options.SystemName)
if err := writeConfig(a.provider, a.options.ConfigPath); err != nil {
- return fmt.Errorf("failed to write config: %w", err)
+ return fmt.Errorf("auth0: failed to write config: %w", err)
}
return nil
}
diff --git a/client/go/internal/cli/auth/zts/zts.go b/client/go/internal/cli/auth/zts/zts.go
index caa2d03367d..2c66ff13e8b 100644
--- a/client/go/internal/cli/auth/zts/zts.go
+++ b/client/go/internal/cli/auth/zts/zts.go
@@ -1,7 +1,6 @@
package zts
import (
- "crypto/tls"
"encoding/json"
"fmt"
"net/http"
@@ -18,26 +17,39 @@ const DefaultURL = "https://zts.athenz.ouroath.com:4443"
type Client struct {
client util.HTTPClient
tokenURL *url.URL
+ domain string
}
// NewClient creates a new client for an Athenz ZTS service located at serviceURL.
-func NewClient(client util.HTTPClient, serviceURL string) (*Client, error) {
+func NewClient(client util.HTTPClient, domain, serviceURL string) (*Client, error) {
tokenURL, err := url.Parse(serviceURL)
if err != nil {
return nil, err
}
tokenURL.Path = "/zts/v1/oauth2/token"
- return &Client{tokenURL: tokenURL, client: client}, nil
+ return &Client{tokenURL: tokenURL, client: client, domain: domain}, nil
}
-// AccessToken returns an access token within the given domain, using certificate to authenticate with ZTS.
-func (c *Client) AccessToken(domain string, certificate tls.Certificate) (string, error) {
- data := fmt.Sprintf("grant_type=client_credentials&scope=%s:domain", domain)
+func (c *Client) Authenticate(request *http.Request) error {
+ accessToken, err := c.AccessToken()
+ if err != nil {
+ return err
+ }
+ if request.Header == nil {
+ request.Header = make(http.Header)
+ }
+ request.Header.Add("Authorization", "Bearer "+accessToken)
+ return nil
+}
+
+// AccessToken returns an access token within the domain configured in client c.
+func (c *Client) AccessToken() (string, error) {
+ // TODO(mpolden): This should cache and re-use tokens until expiry
+ data := fmt.Sprintf("grant_type=client_credentials&scope=%s:domain", c.domain)
req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(data))
if err != nil {
return "", err
}
- util.SetCertificates(c.client, []tls.Certificate{certificate})
response, err := c.client.Do(req, 10*time.Second)
if err != nil {
return "", err
@@ -45,7 +57,7 @@ func (c *Client) AccessToken(domain string, certificate tls.Certificate) (string
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
- return "", fmt.Errorf("got status %d from %s", response.StatusCode, c.tokenURL.String())
+ return "", fmt.Errorf("zts: got status %d from %s", response.StatusCode, c.tokenURL.String())
}
var ztsResponse struct {
AccessToken string `json:"access_token"`
diff --git a/client/go/internal/cli/auth/zts/zts_test.go b/client/go/internal/cli/auth/zts/zts_test.go
index d0cc7ea9f9d..1c75a94ee03 100644
--- a/client/go/internal/cli/auth/zts/zts_test.go
+++ b/client/go/internal/cli/auth/zts/zts_test.go
@@ -1,7 +1,6 @@
package zts
import (
- "crypto/tls"
"testing"
"github.com/vespa-engine/vespa/client/go/internal/mock"
@@ -9,17 +8,17 @@ import (
func TestAccessToken(t *testing.T) {
httpClient := mock.HTTPClient{}
- client, err := NewClient(&httpClient, "http://example.com")
+ client, err := NewClient(&httpClient, "vespa.vespa", "http://example.com")
if err != nil {
t.Fatal(err)
}
httpClient.NextResponseString(400, `{"message": "bad request"}`)
- _, err = client.AccessToken("vespa.vespa", tls.Certificate{})
+ _, err = client.AccessToken()
if err == nil {
t.Fatal("want error for non-ok response status")
}
httpClient.NextResponseString(200, `{"access_token": "foo bar"}`)
- token, err := client.AccessToken("vespa.vespa", tls.Certificate{})
+ token, err := client.AccessToken()
if err != nil {
t.Fatal(err)
}
diff --git a/client/go/internal/cli/cmd/cert.go b/client/go/internal/cli/cmd/cert.go
index 7f79a9db358..48bad974c3f 100644
--- a/client/go/internal/cli/cmd/cert.go
+++ b/client/go/internal/cli/cmd/cert.go
@@ -34,13 +34,18 @@ package specified as an argument to this command (default '.').
It's possible to override the private key and certificate used through
environment variables. This can be useful in continuous integration systems.
-Example of setting the certificate and key in-line:
+It's also possible override the CA certificate which can be useful when using self-signed certificates with a
+self-hosted Vespa service. See https://docs.vespa.ai/en/mtls.html for more information.
+Example of setting the CA certificate, certificate and key in-line:
+
+ export VESPA_CLI_DATA_PLANE_CA_CERT="my CA cert"
export VESPA_CLI_DATA_PLANE_CERT="my cert"
export VESPA_CLI_DATA_PLANE_KEY="my private key"
-Example of loading certificate and key from custom paths:
+Example of loading CA certificate, certificate and key from custom paths:
+ export VESPA_CLI_DATA_PLANE_CA_CERT_FILE=/path/to/cacert
export VESPA_CLI_DATA_PLANE_CERT_FILE=/path/to/cert
export VESPA_CLI_DATA_PLANE_KEY_FILE=/path/to/key
diff --git a/client/go/internal/cli/cmd/config.go b/client/go/internal/cli/cmd/config.go
index 2d32c454842..e2132814386 100644
--- a/client/go/internal/cli/cmd/config.go
+++ b/client/go/internal/cli/cmd/config.go
@@ -19,7 +19,6 @@ import (
"github.com/fatih/color"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
- "github.com/vespa-engine/vespa/client/go/internal/cli/auth/auth0"
"github.com/vespa-engine/vespa/client/go/internal/cli/config"
"github.com/vespa-engine/vespa/client/go/internal/vespa"
)
@@ -250,9 +249,10 @@ type Config struct {
}
type KeyPair struct {
- KeyPair tls.Certificate
- CertificateFile string
- PrivateKeyFile string
+ KeyPair tls.Certificate
+ RootCertificates []byte
+ CertificateFile string
+ PrivateKeyFile string
}
func loadConfig(environment map[string]string, flags map[string]*pflag.Flag) (*Config, error) {
@@ -392,6 +392,10 @@ func (c *Config) deploymentIn(system vespa.System) (vespa.Deployment, error) {
return vespa.Deployment{System: system, Application: app, Zone: zone}, nil
}
+func (c *Config) caCertificatePath() string {
+ return c.environment["VESPA_CLI_DATA_PLANE_CA_CERT_FILE"]
+}
+
func (c *Config) certificatePath(app vespa.ApplicationID, targetType string) (string, error) {
if override, ok := c.environment["VESPA_CLI_DATA_PLANE_CERT_FILE"]; ok {
return override, nil
@@ -412,50 +416,68 @@ func (c *Config) privateKeyPath(app vespa.ApplicationID, targetType string) (str
return c.applicationFilePath(app, "data-plane-private-key.pem")
}
-func (c *Config) x509KeyPair(app vespa.ApplicationID, targetType string) (KeyPair, error) {
+func (c *Config) readTLSOptions(app vespa.ApplicationID, targetType string) (vespa.TLSOptions, error) {
+ _, trustAll := c.environment["VESPA_CLI_DATA_PLANE_TRUST_ALL"]
cert, certOk := c.environment["VESPA_CLI_DATA_PLANE_CERT"]
key, keyOk := c.environment["VESPA_CLI_DATA_PLANE_KEY"]
- var (
- kp tls.Certificate
- err error
- certFile string
- keyFile string
- )
+ caCertText, caCertOk := c.environment["VESPA_CLI_DATA_PLANE_CA_CERT"]
+ options := vespa.TLSOptions{TrustAll: trustAll}
+ // CA certificate
+ if caCertOk {
+ options.CACertificate = []byte(caCertText)
+ } else {
+ caCertFile := c.caCertificatePath()
+ if caCertFile != "" {
+ b, err := os.ReadFile(caCertFile)
+ if err != nil {
+ return options, err
+ }
+ options.CACertificate = b
+ options.CACertificateFile = caCertFile
+ }
+ }
+ // Certificate and private key
if certOk && keyOk {
- // Use key pair from environment
- kp, err = tls.X509KeyPair([]byte(cert), []byte(key))
+ kp, err := tls.X509KeyPair([]byte(cert), []byte(key))
+ if err != nil {
+ return vespa.TLSOptions{}, err
+ }
+ options.KeyPair = []tls.Certificate{kp}
} else {
- keyFile, err = c.privateKeyPath(app, targetType)
+ keyFile, err := c.privateKeyPath(app, targetType)
if err != nil {
- return KeyPair{}, err
+ return vespa.TLSOptions{}, err
}
- certFile, err = c.certificatePath(app, targetType)
+ certFile, err := c.certificatePath(app, targetType)
if err != nil {
- return KeyPair{}, err
+ return vespa.TLSOptions{}, err
+ }
+ kp, err := tls.LoadX509KeyPair(certFile, keyFile)
+ if err == nil {
+ options.KeyPair = []tls.Certificate{kp}
+ options.PrivateKeyFile = keyFile
+ options.CertificateFile = certFile
+ } else if err != nil && !os.IsNotExist(err) {
+ return vespa.TLSOptions{}, err
}
- kp, err = tls.LoadX509KeyPair(certFile, keyFile)
- }
- if err != nil {
- return KeyPair{}, err
}
- if targetType == vespa.TargetHosted {
- cert, err := x509.ParseCertificate(kp.Certificate[0])
+ if options.KeyPair != nil {
+ cert, err := x509.ParseCertificate(options.KeyPair[0].Certificate[0])
if err != nil {
- return KeyPair{}, err
+ return vespa.TLSOptions{}, err
}
now := time.Now()
expiredAt := cert.NotAfter
if expiredAt.Before(now) {
delta := now.Sub(expiredAt).Truncate(time.Second)
- return KeyPair{}, fmt.Errorf("certificate %s expired at %s (%s ago)", certFile, cert.NotAfter, delta)
+ source := options.CertificateFile
+ if source == "" {
+ source = "environment"
+ }
+ return vespa.TLSOptions{}, fmt.Errorf("certificate in %s expired at %s (%s ago)", source, cert.NotAfter, delta)
}
- return KeyPair{KeyPair: kp, CertificateFile: certFile, PrivateKeyFile: keyFile}, nil
}
- return KeyPair{
- KeyPair: kp,
- CertificateFile: certFile,
- PrivateKeyFile: keyFile,
- }, nil
+ return options, nil
}
func (c *Config) apiKeyFileFromEnv() (string, bool) {
@@ -490,11 +512,10 @@ func (c *Config) readAPIKey(cli *CLI, system vespa.System, tenantName string) ([
return nil, nil // Vespa Cloud CI only talks to data plane and does not have an API key
}
if !cli.isCI() {
- client, err := cli.auth0Factory(cli.httpClient, auth0.Options{ConfigPath: c.authConfigPath(), SystemName: system.Name, SystemURL: system.URL})
- if err == nil && client.HasCredentials() {
- return nil, nil // use Auth0
+ if _, err := os.Stat(c.authConfigPath()); err == nil {
+ return nil, nil // We have auth config, so we should prefer Auth0 over API key
}
- cli.printWarning("Authenticating with API key. This is discouraged in non-CI environments", "Authenticate with 'vespa auth login'")
+ cli.printWarning("Authenticating with API key. This is discouraged in non-CI environments", "Authenticate with 'vespa auth login' instead")
}
return os.ReadFile(c.apiKeyPath(tenantName))
}
diff --git a/client/go/internal/cli/cmd/config_test.go b/client/go/internal/cli/cmd/config_test.go
index 458878b4356..66b65bf402b 100644
--- a/client/go/internal/cli/cmd/config_test.go
+++ b/client/go/internal/cli/cmd/config_test.go
@@ -2,15 +2,21 @@
package cmd
import (
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "math/big"
"os"
"path/filepath"
"testing"
+ "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/vespa-engine/vespa/client/go/internal/cli/auth/auth0"
"github.com/vespa-engine/vespa/client/go/internal/mock"
- "github.com/vespa-engine/vespa/client/go/internal/util"
"github.com/vespa-engine/vespa/client/go/internal/vespa"
)
@@ -166,7 +172,7 @@ func TestReadAPIKey(t *testing.T) {
require.Nil(t, err)
assert.Equal(t, []byte("foo"), key)
- // Cloud CI does not read key from disk as it's not expected to have any
+ // Cloud CI never reads key from disk as it's not expected to have any
cli, _, _ = newTestCLI(t, "VESPA_CLI_CLOUD_CI=true")
key, err = cli.config.readAPIKey(cli, vespa.PublicSystem, "t1")
require.Nil(t, err)
@@ -186,12 +192,111 @@ func TestReadAPIKey(t *testing.T) {
require.Nil(t, err)
assert.Equal(t, []byte("baz"), key)
- // Auth0 is preferred when configured
+ // Prefer Auth0 if we have auth config
cli, _, _ = newTestCLI(t)
- cli.auth0Factory = func(httpClient util.HTTPClient, options auth0.Options) (auth0Client, error) {
- return &mockAuth0{hasCredentials: true}, nil
- }
+ require.Nil(t, os.WriteFile(filepath.Join(cli.config.homeDir, "auth.json"), []byte("foo"), 0600))
key, err = cli.config.readAPIKey(cli, vespa.PublicSystem, "t1")
require.Nil(t, err)
assert.Nil(t, key)
}
+
+func TestConfigReadTLSOptions(t *testing.T) {
+ app := vespa.ApplicationID{Tenant: "t1", Application: "a1", Instance: "i1"}
+ homeDir := t.TempDir()
+
+ // No environment variables, and no files on disk
+ assertTLSOptions(t, homeDir, app, vespa.TargetLocal, vespa.TLSOptions{})
+
+ // A single environment variable is set
+ assertTLSOptions(t, homeDir, app, vespa.TargetLocal, vespa.TLSOptions{TrustAll: true}, "VESPA_CLI_DATA_PLANE_TRUST_ALL=true")
+
+ // Key pair is provided in-line in environment variables
+ pemCert, pemKey, keyPair := createKeyPair(t)
+ assertTLSOptions(t, homeDir, app,
+ vespa.TargetLocal,
+ vespa.TLSOptions{
+ TrustAll: true,
+ CACertificate: []byte("cacert"),
+ KeyPair: []tls.Certificate{keyPair},
+ },
+ "VESPA_CLI_DATA_PLANE_TRUST_ALL=true",
+ "VESPA_CLI_DATA_PLANE_CA_CERT=cacert",
+ "VESPA_CLI_DATA_PLANE_CERT="+string(pemCert),
+ "VESPA_CLI_DATA_PLANE_KEY="+string(pemKey),
+ )
+
+ // Key pair is provided as file paths through environment variables
+ certFile := filepath.Join(homeDir, "cert")
+ keyFile := filepath.Join(homeDir, "key")
+ caCertFile := filepath.Join(homeDir, "cacert")
+ require.Nil(t, os.WriteFile(certFile, pemCert, 0600))
+ require.Nil(t, os.WriteFile(keyFile, pemKey, 0600))
+ require.Nil(t, os.WriteFile(caCertFile, []byte("cacert"), 0600))
+ assertTLSOptions(t, homeDir, app,
+ vespa.TargetLocal,
+ vespa.TLSOptions{
+ KeyPair: []tls.Certificate{keyPair},
+ CACertificate: []byte("cacert"),
+ CACertificateFile: caCertFile,
+ CertificateFile: certFile,
+ PrivateKeyFile: keyFile,
+ },
+ "VESPA_CLI_DATA_PLANE_CERT_FILE="+certFile,
+ "VESPA_CLI_DATA_PLANE_KEY_FILE="+keyFile,
+ "VESPA_CLI_DATA_PLANE_CA_CERT_FILE="+caCertFile,
+ )
+
+ // Key pair resides in default paths
+ defaultCertFile := filepath.Join(homeDir, app.String(), "data-plane-public-cert.pem")
+ defaultKeyFile := filepath.Join(homeDir, app.String(), "data-plane-private-key.pem")
+ require.Nil(t, os.WriteFile(defaultCertFile, pemCert, 0600))
+ require.Nil(t, os.WriteFile(defaultKeyFile, pemKey, 0600))
+ assertTLSOptions(t, homeDir, app,
+ vespa.TargetLocal,
+ vespa.TLSOptions{
+ KeyPair: []tls.Certificate{keyPair},
+ CertificateFile: defaultCertFile,
+ PrivateKeyFile: defaultKeyFile,
+ },
+ )
+}
+
+func assertTLSOptions(t *testing.T, homeDir string, app vespa.ApplicationID, target string, want vespa.TLSOptions, envVars ...string) {
+ t.Helper()
+ envVars = append(envVars, "VESPA_CLI_HOME="+homeDir)
+ cli, _, _ := newTestCLI(t, envVars...)
+ require.Nil(t, cli.Run("config", "set", "application", app.String()))
+ config, err := cli.config.readTLSOptions(app, vespa.TargetLocal)
+ require.Nil(t, err)
+ assert.Equal(t, want, config)
+}
+
+func createKeyPair(t *testing.T) ([]byte, []byte, tls.Certificate) {
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ t.Fatal(err)
+ }
+ notBefore := time.Now()
+ notAfter := notBefore.Add(24 * time.Hour)
+ template := x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{CommonName: "example.com"},
+ NotBefore: notBefore,
+ NotAfter: notAfter,
+ }
+ certificateDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ privateKeyDER, err := x509.MarshalPKCS8PrivateKey(privateKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ pemCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certificateDER})
+ pemKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyDER})
+ kp, err := tls.X509KeyPair(pemCert, pemKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return pemCert, pemKey, kp
+}
diff --git a/client/go/internal/cli/cmd/curl.go b/client/go/internal/cli/cmd/curl.go
index 8fcd1fa6ef7..3d5aaff24dc 100644
--- a/client/go/internal/cli/cmd/curl.go
+++ b/client/go/internal/cli/cmd/curl.go
@@ -4,7 +4,6 @@ package cmd
import (
"fmt"
"log"
- "net/http"
"os"
"strings"
@@ -54,6 +53,7 @@ $ vespa curl -- -v --data-urlencode "yql=select * from music where album contain
return err
}
case vespa.DocumentService, vespa.QueryService:
+ c.CaCertificate = service.TLSOptions.CACertificateFile
c.PrivateKey = service.TLSOptions.PrivateKeyFile
c.Certificate = service.TLSOptions.CertificateFile
default:
@@ -79,15 +79,7 @@ func addAccessToken(cmd *curl.Command, target vespa.Target) error {
if target.Type() != vespa.TargetCloud {
return nil
}
- req := http.Request{}
- if err := target.SignRequest(&req, ""); err != nil {
- return err
- }
- headerValue := req.Header.Get("Authorization")
- if headerValue == "" {
- return fmt.Errorf("no authorization header added when signing request")
- }
- cmd.Header("Authorization", headerValue)
+ cmd.Header("Authorization", "secret")
return nil
}
diff --git a/client/go/internal/cli/cmd/feed.go b/client/go/internal/cli/cmd/feed.go
index 895a22d2be5..06568dd35c3 100644
--- a/client/go/internal/cli/cmd/feed.go
+++ b/client/go/internal/cli/cmd/feed.go
@@ -6,6 +6,7 @@ import (
"io"
"math"
"os"
+ "runtime/pprof"
"time"
"github.com/spf13/cobra"
@@ -14,16 +15,35 @@ import (
"github.com/vespa-engine/vespa/client/go/internal/vespa/document"
)
-func addFeedFlags(cmd *cobra.Command, verbose *bool, connections *int) {
- cmd.PersistentFlags().IntVarP(connections, "connections", "N", 8, "The number of connections to use")
- cmd.PersistentFlags().BoolVarP(verbose, "verbose", "v", false, "Verbose mode. Print errors as they happen")
+func addFeedFlags(cmd *cobra.Command, options *feedOptions) {
+ cmd.PersistentFlags().IntVar(&options.connections, "connections", 8, "The number of connections to use")
+ cmd.PersistentFlags().StringVar(&options.compression, "compression", "auto", `Compression mode to use. Default is "auto" which compresses large documents. Must be "auto", "gzip" or "none"`)
+ cmd.PersistentFlags().StringVar(&options.route, "route", "", "Target Vespa route for feed operations")
+ cmd.PersistentFlags().IntVar(&options.traceLevel, "trace", 0, "The trace level of network traffic. 0 to disable")
+ cmd.PersistentFlags().IntVar(&options.timeoutSecs, "timeout", 0, "Feed operation timeout in seconds. 0 to disable")
+ cmd.PersistentFlags().BoolVar(&options.verbose, "verbose", false, "Verbose mode. Print successful operations in addition to errors")
+ memprofile := "memprofile"
+ cpuprofile := "cpuprofile"
+ cmd.PersistentFlags().StringVar(&options.memprofile, memprofile, "", "Write a heap profile to given file")
+ cmd.PersistentFlags().StringVar(&options.cpuprofile, cpuprofile, "", "Write a CPU profile to given file")
+ // Hide these flags as they are intended for internal use
+ cmd.PersistentFlags().MarkHidden(memprofile)
+ cmd.PersistentFlags().MarkHidden(cpuprofile)
+}
+
+type feedOptions struct {
+ connections int
+ compression string
+ route string
+ verbose bool
+ traceLevel int
+ timeoutSecs int
+ memprofile string
+ cpuprofile string
}
func newFeedCmd(cli *CLI) *cobra.Command {
- var (
- verbose bool
- connections int
- )
+ var options feedOptions
cmd := &cobra.Command{
Use: "feed FILE",
Short: "Feed documents to a Vespa cluster",
@@ -56,10 +76,27 @@ $ cat documents.jsonl | vespa feed -
defer f.Close()
r = f
}
- return feed(r, cli, verbose, connections)
+ if options.cpuprofile != "" {
+ f, err := os.Create(options.cpuprofile)
+ if err != nil {
+ return err
+ }
+ pprof.StartCPUProfile(f)
+ defer pprof.StopCPUProfile()
+ }
+ err := feed(r, cli, options)
+ if options.memprofile != "" {
+ f, err := os.Create(options.memprofile)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ pprof.WriteHeapProfile(f)
+ }
+ return err
},
}
- addFeedFlags(cmd, &verbose, &connections)
+ addFeedFlags(cmd, &options)
return cmd
}
@@ -67,29 +104,46 @@ func createServiceClients(service *vespa.Service, n int) []util.HTTPClient {
clients := make([]util.HTTPClient, 0, n)
for i := 0; i < n; i++ {
client := service.Client().Clone()
- util.ForceHTTP2(client, service.TLSOptions.KeyPair) // Feeding should always use HTTP/2
+ // Feeding should always use HTTP/2
+ util.ForceHTTP2(client, service.TLSOptions.KeyPair, service.TLSOptions.CACertificate, service.TLSOptions.TrustAll)
clients = append(clients, client)
}
return clients
}
-func feed(r io.Reader, cli *CLI, verbose bool, connections int) error {
+func (opts feedOptions) compressionMode() (document.Compression, error) {
+ switch opts.compression {
+ case "auto":
+ return document.CompressionAuto, nil
+ case "none":
+ return document.CompressionNone, nil
+ case "gzip":
+ return document.CompressionGzip, nil
+ }
+ return 0, errHint(fmt.Errorf("invalid compression mode: %s", opts.compression), `Must be "auto", "gzip" or "none"`)
+}
+
+func feed(r io.Reader, cli *CLI, options feedOptions) error {
service, err := documentService(cli)
if err != nil {
return err
}
- clients := createServiceClients(service, connections)
+ clients := createServiceClients(service, options.connections)
+ compression, err := options.compressionMode()
+ if err != nil {
+ return err
+ }
client := document.NewClient(document.ClientOptions{
- BaseURL: service.BaseURL,
+ Compression: compression,
+ Timeout: time.Duration(options.timeoutSecs) * time.Second,
+ Route: options.route,
+ TraceLevel: options.traceLevel,
+ BaseURL: service.BaseURL,
}, clients)
- throttler := document.NewThrottler(connections)
+ throttler := document.NewThrottler(options.connections)
// TODO(mpolden): Make doom duration configurable
circuitBreaker := document.NewCircuitBreaker(10*time.Second, 0)
- errWriter := io.Discard
- if verbose {
- errWriter = cli.Stderr
- }
- dispatcher := document.NewDispatcher(client, throttler, circuitBreaker, errWriter)
+ dispatcher := document.NewDispatcher(client, throttler, circuitBreaker, cli.Stderr, options.verbose)
dec := document.NewDecoder(r)
start := cli.now()
diff --git a/client/go/internal/cli/cmd/root.go b/client/go/internal/cli/cmd/root.go
index 360af9d0dcf..88d43411983 100644
--- a/client/go/internal/cli/cmd/root.go
+++ b/client/go/internal/cli/cmd/root.go
@@ -2,7 +2,6 @@
package cmd
import (
- "crypto/tls"
"encoding/json"
"fmt"
"io"
@@ -88,18 +87,9 @@ func (c *execSubprocess) Run(name string, args ...string) ([]byte, error) {
return exec.Command(name, args...).Output()
}
-type ztsClient interface {
- AccessToken(domain string, certficiate tls.Certificate) (string, error)
-}
-
-type auth0Client interface {
- AccessToken() (string, error)
- HasCredentials() bool
-}
-
-type auth0Factory func(httpClient util.HTTPClient, options auth0.Options) (auth0Client, error)
+type auth0Factory func(httpClient util.HTTPClient, options auth0.Options) (vespa.Authenticator, error)
-type ztsFactory func(httpClient util.HTTPClient, url string) (ztsClient, error)
+type ztsFactory func(httpClient util.HTTPClient, domain, url string) (vespa.Authenticator, error)
// New creates the Vespa CLI, writing output to stdout and stderr, and reading environment variables from environment.
func New(stdout, stderr io.Writer, environment []string) (*CLI, error) {
@@ -143,11 +133,11 @@ For detailed description of flags and configuration, see 'vespa help config'.
httpClient: util.CreateClient(time.Second * 10),
exec: &execSubprocess{},
now: time.Now,
- auth0Factory: func(httpClient util.HTTPClient, options auth0.Options) (auth0Client, error) {
+ auth0Factory: func(httpClient util.HTTPClient, options auth0.Options) (vespa.Authenticator, error) {
return auth0.NewClient(httpClient, options)
},
- ztsFactory: func(httpClient util.HTTPClient, url string) (ztsClient, error) {
- return zts.NewClient(httpClient, url)
+ ztsFactory: func(httpClient util.HTTPClient, domain, url string) (vespa.Authenticator, error) {
+ return zts.NewClient(httpClient, domain, url)
},
}
cli.isTerminal = func() bool { return isTerminal(cli.Stdout) && isTerminal(cli.Stderr) }
@@ -321,16 +311,34 @@ func (c *CLI) createTarget(opts targetOptions) (vespa.Target, error) {
if err != nil {
return nil, err
}
+ customURL := ""
if strings.HasPrefix(targetType, "http") {
- return vespa.CustomTarget(c.httpClient, targetType), nil
+ customURL = targetType
+ targetType = vespa.TargetCustom
}
switch targetType {
- case vespa.TargetLocal:
- return vespa.LocalTarget(c.httpClient), nil
+ case vespa.TargetLocal, vespa.TargetCustom:
+ return c.createCustomTarget(targetType, customURL)
case vespa.TargetCloud, vespa.TargetHosted:
return c.createCloudTarget(targetType, opts)
+ default:
+ return nil, errHint(fmt.Errorf("invalid target: %s", targetType), "Valid targets are 'local', 'cloud', 'hosted' or an URL")
+ }
+}
+
+func (c *CLI) createCustomTarget(targetType, customURL string) (vespa.Target, error) {
+ tlsOptions, err := c.config.readTLSOptions(vespa.DefaultApplication, targetType)
+ if err != nil {
+ return nil, err
+ }
+ switch targetType {
+ case vespa.TargetLocal:
+ return vespa.LocalTarget(c.httpClient, tlsOptions), nil
+ case vespa.TargetCustom:
+ return vespa.CustomTarget(c.httpClient, customURL, tlsOptions), nil
+ default:
+ return nil, fmt.Errorf("invalid custom target: %s", targetType)
}
- return nil, errHint(fmt.Errorf("invalid target: %s", targetType), "Valid targets are 'local', 'cloud', 'hosted' or an URL")
}
func (c *CLI) createCloudTarget(targetType string, opts targetOptions) (vespa.Target, error) {
@@ -347,48 +355,53 @@ func (c *CLI) createCloudTarget(targetType string, opts targetOptions) (vespa.Ta
return nil, err
}
var (
- apiKey []byte
- authConfigPath string
+ apiAuth vespa.Authenticator
+ deploymentAuth vespa.Authenticator
apiTLSOptions vespa.TLSOptions
deploymentTLSOptions vespa.TLSOptions
)
switch targetType {
case vespa.TargetCloud:
- apiKey, err = c.config.readAPIKey(c, system, deployment.Application.Tenant)
+ apiKey, err := c.config.readAPIKey(c, system, deployment.Application.Tenant)
if err != nil {
return nil, err
}
- authConfigPath = c.config.authConfigPath()
+ if apiKey == nil {
+ authConfigPath := c.config.authConfigPath()
+ auth0, err := c.auth0Factory(c.httpClient, auth0.Options{ConfigPath: authConfigPath, SystemName: system.Name, SystemURL: system.URL})
+ if err != nil {
+ return nil, err
+ }
+ apiAuth = auth0
+ } else {
+ apiAuth = vespa.NewRequestSigner(deployment.Application.SerializedForm(), apiKey)
+ }
deploymentTLSOptions = vespa.TLSOptions{}
if !opts.noCertificate {
- kp, err := c.config.x509KeyPair(deployment.Application, targetType)
+ kp, err := c.config.readTLSOptions(deployment.Application, targetType)
if err != nil {
- return nil, errHint(err, "Deployment to cloud requires a certificate. Try 'vespa auth cert'")
- }
- deploymentTLSOptions = vespa.TLSOptions{
- KeyPair: []tls.Certificate{kp.KeyPair},
- CertificateFile: kp.CertificateFile,
- PrivateKeyFile: kp.PrivateKeyFile,
+ return nil, errHint(err, "Deployment to cloud requires a certificate", "Try 'vespa auth cert' to create a self-signed certificate")
}
+ deploymentTLSOptions = kp
}
case vespa.TargetHosted:
- kp, err := c.config.x509KeyPair(deployment.Application, targetType)
+ kp, err := c.config.readTLSOptions(deployment.Application, targetType)
if err != nil {
return nil, errHint(err, "Deployment to hosted requires an Athenz certificate", "Try renewing certificate with 'athenz-user-cert'")
}
- apiTLSOptions = vespa.TLSOptions{
- KeyPair: []tls.Certificate{kp.KeyPair},
- CertificateFile: kp.CertificateFile,
- PrivateKeyFile: kp.PrivateKeyFile,
+ zts, err := c.ztsFactory(c.httpClient, system.AthenzDomain, zts.DefaultURL)
+ if err != nil {
+ return nil, err
}
- deploymentTLSOptions = apiTLSOptions
+ deploymentAuth = zts
+ apiTLSOptions = kp
+ deploymentTLSOptions = kp
default:
return nil, fmt.Errorf("invalid cloud target: %s", targetType)
}
apiOptions := vespa.APIOptions{
System: system,
TLSOptions: apiTLSOptions,
- APIKey: apiKey,
}
deploymentOptions := vespa.CloudDeploymentOptions{
Deployment: deployment,
@@ -403,15 +416,7 @@ func (c *CLI) createCloudTarget(targetType string, opts targetOptions) (vespa.Ta
Writer: c.Stdout,
Level: vespa.LogLevel(logLevel),
}
- auth0, err := c.auth0Factory(c.httpClient, auth0.Options{ConfigPath: authConfigPath, SystemName: apiOptions.System.Name, SystemURL: apiOptions.System.URL})
- if err != nil {
- return nil, err
- }
- zts, err := c.ztsFactory(c.httpClient, zts.DefaultURL)
- if err != nil {
- return nil, err
- }
- return vespa.CloudTarget(c.httpClient, zts, auth0, apiOptions, deploymentOptions, logOptions)
+ return vespa.CloudTarget(c.httpClient, apiAuth, deploymentAuth, apiOptions, deploymentOptions, logOptions)
}
// system returns the appropiate system for the target configured in this CLI.
@@ -460,7 +465,6 @@ func (c *CLI) createDeploymentOptions(pkg vespa.ApplicationPackage, target vespa
ApplicationPackage: pkg,
Target: target,
Timeout: timeout,
- HTTPClient: c.httpClient,
}, nil
}
diff --git a/client/go/internal/cli/cmd/test.go b/client/go/internal/cli/cmd/test.go
index 05633b1135e..8c4501e2870 100644
--- a/client/go/internal/cli/cmd/test.go
+++ b/client/go/internal/cli/cmd/test.go
@@ -263,7 +263,7 @@ func verify(step step, defaultCluster string, defaultParameters map[string]strin
var response *http.Response
if externalEndpoint {
- util.SetCertificates(context.cli.httpClient, []tls.Certificate{})
+ util.ConfigureTLS(context.cli.httpClient, []tls.Certificate{}, nil, false)
response, err = context.cli.httpClient.Do(request, 60*time.Second)
} else {
response, err = service.Do(request, 600*time.Second) // Vespa should provide a response within the given request timeout
diff --git a/client/go/internal/cli/cmd/testutil_test.go b/client/go/internal/cli/cmd/testutil_test.go
index 61f8dab2264..492e40d8855 100644
--- a/client/go/internal/cli/cmd/testutil_test.go
+++ b/client/go/internal/cli/cmd/testutil_test.go
@@ -3,13 +3,14 @@ package cmd
import (
"bytes"
- "crypto/tls"
+ "net/http"
"path/filepath"
"testing"
"github.com/vespa-engine/vespa/client/go/internal/cli/auth/auth0"
"github.com/vespa-engine/vespa/client/go/internal/mock"
"github.com/vespa-engine/vespa/client/go/internal/util"
+ "github.com/vespa-engine/vespa/client/go/internal/vespa"
)
func newTestCLI(t *testing.T, envVars ...string) (*CLI, *bytes.Buffer, *bytes.Buffer) {
@@ -29,21 +30,15 @@ func newTestCLI(t *testing.T, envVars ...string) (*CLI, *bytes.Buffer, *bytes.Bu
httpClient := &mock.HTTPClient{}
cli.httpClient = httpClient
cli.exec = &mock.Exec{}
- cli.auth0Factory = func(httpClient util.HTTPClient, options auth0.Options) (auth0Client, error) {
- return &mockAuth0{}, nil
+ cli.auth0Factory = func(httpClient util.HTTPClient, options auth0.Options) (vespa.Authenticator, error) {
+ return &mockAuthenticator{}, nil
}
- cli.ztsFactory = func(httpClient util.HTTPClient, url string) (ztsClient, error) {
- return &mockZTS{}, nil
+ cli.ztsFactory = func(httpClient util.HTTPClient, domain, url string) (vespa.Authenticator, error) {
+ return &mockAuthenticator{}, nil
}
return cli, &stdout, &stderr
}
-type mockZTS struct{}
+type mockAuthenticator struct{}
-func (z *mockZTS) AccessToken(domain string, cert tls.Certificate) (string, error) { return "", nil }
-
-type mockAuth0 struct{ hasCredentials bool }
-
-func (a *mockAuth0) AccessToken() (string, error) { return "", nil }
-
-func (a *mockAuth0) HasCredentials() bool { return a.hasCredentials }
+func (a *mockAuthenticator) Authenticate(request *http.Request) error { return nil }
diff --git a/client/go/internal/util/http.go b/client/go/internal/util/http.go
index dcf05ed3a14..8a67b24dffb 100644
--- a/client/go/internal/util/http.go
+++ b/client/go/internal/util/http.go
@@ -4,6 +4,7 @@ package util
import (
"context"
"crypto/tls"
+ "crypto/x509"
"fmt"
"net"
"net/http"
@@ -35,7 +36,7 @@ func (c *defaultHTTPClient) Do(request *http.Request, timeout time.Duration) (re
func (c *defaultHTTPClient) Clone() HTTPClient { return CreateClient(c.client.Timeout) }
-func SetCertificates(client HTTPClient, certificates []tls.Certificate) {
+func ConfigureTLS(client HTTPClient, certificates []tls.Certificate, caCertificate []byte, trustAll bool) {
c, ok := client.(*defaultHTTPClient)
if !ok {
return
@@ -43,8 +44,14 @@ func SetCertificates(client HTTPClient, certificates []tls.Certificate) {
var tlsConfig *tls.Config = nil
if certificates != nil {
tlsConfig = &tls.Config{
- Certificates: certificates,
- MinVersion: tls.VersionTLS12,
+ Certificates: certificates,
+ MinVersion: tls.VersionTLS12,
+ InsecureSkipVerify: trustAll,
+ }
+ if caCertificate != nil {
+ certs := x509.NewCertPool()
+ certs.AppendCertsFromPEM(caCertificate)
+ tlsConfig.RootCAs = certs
}
}
if tr, ok := c.client.Transport.(*http.Transport); ok {
@@ -56,19 +63,13 @@ func SetCertificates(client HTTPClient, certificates []tls.Certificate) {
}
}
-func ForceHTTP2(client HTTPClient, certificates []tls.Certificate) {
+func ForceHTTP2(client HTTPClient, certificates []tls.Certificate, caCertificate []byte, trustAll bool) {
c, ok := client.(*defaultHTTPClient)
if !ok {
return
}
- var tlsConfig *tls.Config = nil
var dialFunc func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error)
- if certificates != nil {
- tlsConfig = &tls.Config{
- Certificates: certificates,
- MinVersion: tls.VersionTLS12,
- }
- } else {
+ if certificates == nil {
// No certificate, so force H2C (HTTP/2 over clear-text) by using a non-TLS Dialer
dialer := net.Dialer{}
dialFunc = func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
@@ -80,10 +81,10 @@ func ForceHTTP2(client HTTPClient, certificates []tls.Certificate) {
// https://github.com/golang/go/issues/16582
// https://github.com/golang/go/issues/22091
c.client.Transport = &http2.Transport{
- AllowHTTP: true,
- TLSClientConfig: tlsConfig,
- DialTLSContext: dialFunc,
+ AllowHTTP: true,
+ DialTLSContext: dialFunc,
}
+ ConfigureTLS(client, certificates, caCertificate, trustAll)
}
func CreateClient(timeout time.Duration) HTTPClient {
diff --git a/client/go/internal/vespa/crypto.go b/client/go/internal/vespa/crypto.go
index 9621d0c1180..5e273538869 100644
--- a/client/go/internal/vespa/crypto.go
+++ b/client/go/internal/vespa/crypto.go
@@ -111,6 +111,8 @@ func NewRequestSigner(keyID string, pemPrivateKey []byte) *RequestSigner {
}
}
+func (rs *RequestSigner) Authenticate(request *http.Request) error { return rs.SignRequest(request) }
+
// SignRequest signs the given HTTP request using the private key in rs
func (rs *RequestSigner) SignRequest(request *http.Request) error {
timestamp := rs.now().UTC().Format(time.RFC3339)
diff --git a/client/go/internal/vespa/deploy.go b/client/go/internal/vespa/deploy.go
index 687bfc46124..f633c8ed9ee 100644
--- a/client/go/internal/vespa/deploy.go
+++ b/client/go/internal/vespa/deploy.go
@@ -45,7 +45,6 @@ type DeploymentOptions struct {
ApplicationPackage ApplicationPackage
Timeout time.Duration
Version version.Version
- HTTPClient util.HTTPClient
}
type LogLinePrepareResponse struct {
@@ -130,7 +129,7 @@ func Prepare(deployment DeploymentOptions) (PrepareResult, error) {
return PrepareResult{}, err
}
serviceDescription := "Deploy service"
- response, err := deployment.HTTPClient.Do(req, time.Second*30)
+ response, err := deployServiceDo(req, time.Second*30, deployment)
if err != nil {
return PrepareResult{}, err
}
@@ -171,7 +170,7 @@ func Activate(sessionID int64, deployment DeploymentOptions) error {
return err
}
serviceDescription := "Deploy service"
- response, err := deployment.HTTPClient.Do(req, time.Second*30)
+ response, err := deployServiceDo(req, time.Second*30, deployment)
if err != nil {
return err
}
@@ -263,11 +262,7 @@ func Submit(opts DeploymentOptions) error {
}
request.Header.Set("Content-Type", writer.FormDataContentType())
serviceDescription := "Submit service"
- sigKeyId := opts.Target.Deployment().Application.SerializedForm()
- if err := opts.Target.SignRequest(request, sigKeyId); err != nil {
- return fmt.Errorf("failed to sign api request: %w", err)
- }
- response, err := opts.HTTPClient.Do(request, time.Minute*10)
+ response, err := deployServiceDo(request, time.Minute*10, opts)
if err != nil {
return err
}
@@ -275,6 +270,14 @@ func Submit(opts DeploymentOptions) error {
return checkResponse(request, response, serviceDescription)
}
+func deployServiceDo(request *http.Request, timeout time.Duration, opts DeploymentOptions) (*http.Response, error) {
+ s, err := opts.Target.Service(DeployService, 0, 0, "")
+ if err != nil {
+ return nil, err
+ }
+ return s.Do(request, timeout)
+}
+
func checkDeploymentOpts(opts DeploymentOptions) error {
if opts.Target.Type() == TargetCloud && !opts.ApplicationPackage.HasCertificate() {
return fmt.Errorf("%s: missing certificate in package", opts)
@@ -334,11 +337,6 @@ func uploadApplicationPackage(url *url.URL, opts DeploymentOptions) (PrepareResu
if err != nil {
return PrepareResult{}, err
}
-
- keyID := opts.Target.Deployment().Application.SerializedForm()
- if err := opts.Target.SignRequest(request, keyID); err != nil {
- return PrepareResult{}, err
- }
response, err := service.Do(request, time.Minute*10)
if err != nil {
return PrepareResult{}, err
diff --git a/client/go/internal/vespa/deploy_test.go b/client/go/internal/vespa/deploy_test.go
index 3e74e9ab3b6..da2604282c0 100644
--- a/client/go/internal/vespa/deploy_test.go
+++ b/client/go/internal/vespa/deploy_test.go
@@ -19,12 +19,11 @@ import (
func TestDeploy(t *testing.T) {
httpClient := mock.HTTPClient{}
- target := LocalTarget(&httpClient)
+ target := LocalTarget(&httpClient, TLSOptions{})
appDir, _ := mock.ApplicationPackageDir(t, false, false)
opts := DeploymentOptions{
Target: target,
ApplicationPackage: ApplicationPackage{Path: appDir},
- HTTPClient: &httpClient,
}
_, err := Deploy(opts)
assert.Nil(t, err)
@@ -47,7 +46,6 @@ func TestDeployCloud(t *testing.T) {
opts := DeploymentOptions{
Target: target,
ApplicationPackage: ApplicationPackage{Path: appDir},
- HTTPClient: &httpClient,
}
_, err := Deploy(opts)
require.Nil(t, err)
diff --git a/client/go/internal/vespa/document/dispatcher.go b/client/go/internal/vespa/document/dispatcher.go
index 838a7bc45ee..5c99f3bf056 100644
--- a/client/go/internal/vespa/document/dispatcher.go
+++ b/client/go/internal/vespa/document/dispatcher.go
@@ -4,6 +4,7 @@ import (
"container/list"
"fmt"
"io"
+ "strings"
"sync"
"sync/atomic"
"time"
@@ -18,15 +19,19 @@ type Dispatcher struct {
circuitBreaker CircuitBreaker
stats Stats
- started bool
- ready chan Id
- results chan Result
+ started bool
+ ready chan Id
+ results chan Result
+ msgs chan string
+
inflight map[string]*documentGroup
inflightCount int64
- errWriter io.Writer
+ output io.Writer
+ verbose bool
+ listPool sync.Pool
mu sync.RWMutex
- wg sync.WaitGroup
+ workerWg sync.WaitGroup
resultWg sync.WaitGroup
}
@@ -38,30 +43,24 @@ type documentOp struct {
// documentGroup holds document operations which share an ID, and must be dispatched in order.
type documentGroup struct {
- ops *list.List
- mu sync.Mutex
+ q *Queue[documentOp]
+ mu sync.Mutex
}
func (g *documentGroup) add(op documentOp, first bool) {
g.mu.Lock()
defer g.mu.Unlock()
- if g.ops == nil {
- g.ops = list.New()
- }
- if first {
- g.ops.PushFront(op)
- } else {
- g.ops.PushBack(op)
- }
+ g.q.Add(op, first)
}
-func NewDispatcher(feeder Feeder, throttler Throttler, breaker CircuitBreaker, errWriter io.Writer) *Dispatcher {
+func NewDispatcher(feeder Feeder, throttler Throttler, breaker CircuitBreaker, output io.Writer, verbose bool) *Dispatcher {
d := &Dispatcher{
feeder: feeder,
throttler: throttler,
circuitBreaker: breaker,
inflight: make(map[string]*documentGroup),
- errWriter: errWriter,
+ output: output,
+ verbose: verbose,
}
d.start()
return d
@@ -69,16 +68,15 @@ func NewDispatcher(feeder Feeder, throttler Throttler, breaker CircuitBreaker, e
func (d *Dispatcher) sendDocumentIn(group *documentGroup) {
group.mu.Lock()
- defer group.mu.Unlock()
- defer d.releaseSlot()
- first := group.ops.Front()
- if first == nil {
+ op, ok := group.q.Poll()
+ if !ok {
panic("sending from empty document group, this should not happen")
}
- op := group.ops.Remove(first).(documentOp)
op.attempts++
result := d.feeder.Send(op.document)
d.results <- result
+ d.releaseSlot()
+ group.mu.Unlock()
if d.shouldRetry(op, result) {
d.enqueue(op)
}
@@ -86,29 +84,35 @@ func (d *Dispatcher) sendDocumentIn(group *documentGroup) {
func (d *Dispatcher) shouldRetry(op documentOp, result Result) bool {
if result.HTTPStatus/100 == 2 || result.HTTPStatus == 404 || result.HTTPStatus == 412 {
+ if d.verbose {
+ d.msgs <- fmt.Sprintf("feed: successfully fed %s with status %d", op.document.Id, result.HTTPStatus)
+ }
d.throttler.Success()
d.circuitBreaker.Success()
return false
}
if result.HTTPStatus == 429 || result.HTTPStatus == 503 {
- fmt.Fprintf(d.errWriter, "feed: %s was throttled with status %d: retrying\n", op.document, result.HTTPStatus)
+ d.msgs <- fmt.Sprintf("feed: %s was throttled with status %d: retrying\n", op.document, result.HTTPStatus)
d.throttler.Throttled(atomic.LoadInt64(&d.inflightCount))
return true
}
if result.Err != nil || result.HTTPStatus == 500 || result.HTTPStatus == 502 || result.HTTPStatus == 504 {
retry := op.attempts <= maxAttempts
- msg := "feed: " + op.document.String() + " failed with "
+ var msg strings.Builder
+ msg.WriteString("feed: ")
+ msg.WriteString(op.document.String())
if result.Err != nil {
- msg += "error " + result.Err.Error()
+ msg.WriteString("error ")
+ msg.WriteString(result.Err.Error())
} else {
- msg += fmt.Sprintf("status %d", result.HTTPStatus)
+ msg.WriteString(fmt.Sprintf("status %d", result.HTTPStatus))
}
if retry {
- msg += ": retrying"
+ msg.WriteString(": retrying")
} else {
- msg += fmt.Sprintf(": giving up after %d attempts", maxAttempts)
+ msg.WriteString(fmt.Sprintf(": giving up after %d attempts", maxAttempts))
}
- fmt.Fprintln(d.errWriter, msg)
+ d.msgs <- msg.String()
d.circuitBreaker.Error(fmt.Errorf("request failed with status %d", result.HTTPStatus))
if retry {
return true
@@ -123,37 +127,27 @@ func (d *Dispatcher) start() {
if d.started {
return
}
+ d.listPool.New = func() any { return list.New() }
d.ready = make(chan Id, 4096)
d.results = make(chan Result, 4096)
+ d.msgs = make(chan string, 4096)
d.started = true
- d.wg.Add(1)
- go func() {
- defer d.wg.Done()
- d.readDocuments()
- }()
- d.resultWg.Add(1)
- go func() {
- defer d.resultWg.Done()
- d.readResults()
- }()
+ d.resultWg.Add(2)
+ go d.sumStats()
+ go d.printMessages()
}
-func (d *Dispatcher) readDocuments() {
- for id := range d.ready {
- d.mu.RLock()
- group := d.inflight[id.String()]
- d.mu.RUnlock()
- d.wg.Add(1)
- go func() {
- defer d.wg.Done()
- d.sendDocumentIn(group)
- }()
+func (d *Dispatcher) sumStats() {
+ defer d.resultWg.Done()
+ for result := range d.results {
+ d.stats.Add(result.Stats)
}
}
-func (d *Dispatcher) readResults() {
- for result := range d.results {
- d.stats.Add(result.Stats)
+func (d *Dispatcher) printMessages() {
+ defer d.resultWg.Done()
+ for msg := range d.msgs {
+ fmt.Fprintln(d.output, msg)
}
}
@@ -162,10 +156,11 @@ func (d *Dispatcher) enqueue(op documentOp) error {
if !d.started {
return fmt.Errorf("dispatcher is closed")
}
- group, ok := d.inflight[op.document.Id.String()]
+ key := op.document.Id.String()
+ group, ok := d.inflight[key]
if !ok {
- group = &documentGroup{}
- d.inflight[op.document.Id.String()] = group
+ group = &documentGroup{q: NewQueue[documentOp](&d.listPool)}
+ d.inflight[key] = group
}
d.mu.Unlock()
group.add(op, op.attempts > 0)
@@ -177,6 +172,19 @@ func (d *Dispatcher) enqueueWithSlot(id Id) {
d.acquireSlot()
d.ready <- id
d.throttler.Sent()
+ d.dispatch()
+}
+
+func (d *Dispatcher) dispatch() {
+ d.workerWg.Add(1)
+ go func() {
+ defer d.workerWg.Done()
+ id := <-d.ready
+ d.mu.RLock()
+ group := d.inflight[id.String()]
+ d.mu.RUnlock()
+ d.sendDocumentIn(group)
+ }()
}
func (d *Dispatcher) acquireSlot() {
@@ -188,25 +196,20 @@ func (d *Dispatcher) acquireSlot() {
func (d *Dispatcher) releaseSlot() { atomic.AddInt64(&d.inflightCount, -1) }
-func closeAndWait[T any](ch chan T, wg *sync.WaitGroup, d *Dispatcher, markClosed bool) {
- d.mu.Lock()
- if d.started {
- close(ch)
- if markClosed {
- d.started = false
- }
- }
- d.mu.Unlock()
- wg.Wait()
-}
-
func (d *Dispatcher) Enqueue(doc Document) error { return d.enqueue(documentOp{document: doc}) }
func (d *Dispatcher) Stats() Stats { return d.stats }
// Close closes the dispatcher and waits for all inflight operations to complete.
func (d *Dispatcher) Close() error {
- closeAndWait(d.ready, &d.wg, d, false)
- closeAndWait(d.results, &d.resultWg, d, true)
+ d.workerWg.Wait() // Wait for all inflight operations to complete
+ d.mu.Lock()
+ if d.started {
+ close(d.results)
+ close(d.msgs)
+ d.started = false
+ }
+ d.mu.Unlock()
+ d.resultWg.Wait() // Wait for results
return nil
}
diff --git a/client/go/internal/vespa/document/dispatcher_test.go b/client/go/internal/vespa/document/dispatcher_test.go
index 80bc5f603ae..d066f5bc9ae 100644
--- a/client/go/internal/vespa/document/dispatcher_test.go
+++ b/client/go/internal/vespa/document/dispatcher_test.go
@@ -41,7 +41,7 @@ func TestDispatcher(t *testing.T) {
clock := &manualClock{tick: time.Second}
throttler := newThrottler(8, clock.now)
breaker := NewCircuitBreaker(time.Second, 0)
- dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard)
+ dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard, false)
docs := []Document{
{Id: mustParseId("id:ns:type::doc1"), Operation: OperationPut, Body: []byte(`{"fields":{"foo": "123"}}`)},
{Id: mustParseId("id:ns:type::doc2"), Operation: OperationPut, Body: []byte(`{"fields":{"bar": "456"}}`)},
@@ -74,7 +74,7 @@ func TestDispatcherOrdering(t *testing.T) {
clock := &manualClock{tick: time.Second}
throttler := newThrottler(8, clock.now)
breaker := NewCircuitBreaker(time.Second, 0)
- dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard)
+ dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard, false)
for _, d := range docs {
dispatcher.Enqueue(d)
}
@@ -110,7 +110,7 @@ func TestDispatcherOrderingWithFailures(t *testing.T) {
clock := &manualClock{tick: time.Second}
throttler := newThrottler(8, clock.now)
breaker := NewCircuitBreaker(time.Second, 0)
- dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard)
+ dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard, false)
for _, d := range docs {
dispatcher.Enqueue(d)
}
diff --git a/client/go/internal/vespa/document/document.go b/client/go/internal/vespa/document/document.go
index efb60ad8c0a..214d1dc4797 100644
--- a/client/go/internal/vespa/document/document.go
+++ b/client/go/internal/vespa/document/document.go
@@ -14,13 +14,15 @@ var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
type Operation int
const (
- OperationPut = iota
+ OperationPut Operation = iota
OperationUpdate
OperationRemove
)
// Id represents a Vespa document ID.
type Id struct {
+ id string
+
Type string
Namespace string
Number *int64
@@ -36,24 +38,7 @@ func (d Id) Equal(o Id) bool {
d.UserSpecific == o.UserSpecific
}
-func (d Id) String() string {
- var sb strings.Builder
- sb.WriteString("id:")
- sb.WriteString(d.Namespace)
- sb.WriteString(":")
- sb.WriteString(d.Type)
- sb.WriteString(":")
- if d.Number != nil {
- sb.WriteString("n=")
- sb.WriteString(strconv.FormatInt(*d.Number, 10))
- } else if d.Group != "" {
- sb.WriteString("g=")
- sb.WriteString(d.Group)
- }
- sb.WriteString(":")
- sb.WriteString(d.UserSpecific)
- return sb.String()
-}
+func (d Id) String() string { return d.id }
// ParseId parses a serialized document ID string.
func ParseId(serialized string) (Id, error) {
@@ -95,6 +80,7 @@ func ParseId(serialized string) (Id, error) {
return Id{}, parseError(serialized)
}
return Id{
+ id: serialized,
Namespace: namespace,
Type: docType,
Number: number,
diff --git a/client/go/internal/vespa/document/http.go b/client/go/internal/vespa/document/http.go
index 588330a0574..51b6fa4de39 100644
--- a/client/go/internal/vespa/document/http.go
+++ b/client/go/internal/vespa/document/http.go
@@ -2,6 +2,7 @@ package document
import (
"bytes"
+ "compress/gzip"
"encoding/json"
"fmt"
"io"
@@ -16,6 +17,14 @@ import (
"github.com/vespa-engine/vespa/client/go/internal/util"
)
+type Compression int
+
+const (
+ CompressionAuto Compression = iota
+ CompressionNone
+ CompressionGzip
+)
+
// Client represents a HTTP client for the /document/v1/ API.
type Client struct {
options ClientOptions
@@ -26,10 +35,11 @@ type Client struct {
// ClientOptions specifices the configuration options of a feed client.
type ClientOptions struct {
- BaseURL string
- Timeout time.Duration
- Route string
- TraceLevel *int
+ BaseURL string
+ Timeout time.Duration
+ Route string
+ TraceLevel int
+ Compression Compression
}
type countingHTTPClient struct {
@@ -72,14 +82,18 @@ func NewClient(options ClientOptions, httpClients []util.HTTPClient) *Client {
func (c *Client) queryParams() url.Values {
params := url.Values{}
- if c.options.Timeout > 0 {
- params.Set("timeout", strconv.FormatInt(c.options.Timeout.Milliseconds(), 10)+"ms")
+ timeout := c.options.Timeout
+ if timeout == 0 {
+ timeout = 200 * time.Second
+ } else {
+ timeout = timeout*11/10 + 1000
}
+ params.Set("timeout", strconv.FormatInt(timeout.Milliseconds(), 10)+"ms")
if c.options.Route != "" {
params.Set("route", c.options.Route)
}
- if c.options.TraceLevel != nil {
- params.Set("tracelevel", strconv.Itoa(*c.options.TraceLevel))
+ if c.options.TraceLevel > 0 {
+ params.Set("tracelevel", strconv.Itoa(c.options.TraceLevel))
}
return params
}
@@ -148,6 +162,33 @@ func (c *Client) leastBusyClient() *countingHTTPClient {
return &leastBusy
}
+func (c *Client) createRequest(method, url string, body []byte) (*http.Request, error) {
+ var r io.Reader
+ useGzip := c.options.Compression == CompressionGzip || (c.options.Compression == CompressionAuto && len(body) > 512)
+ if useGzip {
+ var buf bytes.Buffer
+ w := gzip.NewWriter(&buf)
+ if _, err := w.Write(body); err != nil {
+ return nil, err
+ }
+ if err := w.Close(); err != nil {
+ return nil, err
+ }
+ r = &buf
+ } else {
+ r = bytes.NewReader(body)
+ }
+ req, err := http.NewRequest(method, url, r)
+ if err != nil {
+ return nil, err
+ }
+ if useGzip {
+ req.Header.Set("Content-Encoding", "gzip")
+ }
+ req.Header.Set("Content-Type", "application/json; charset=utf-8")
+ return req, nil
+}
+
// Send given document to the endpoint configured in this client.
func (c *Client) Send(document Document) Result {
start := c.now()
@@ -156,7 +197,7 @@ func (c *Client) Send(document Document) Result {
if err != nil {
return resultWithErr(result, err)
}
- req, err := http.NewRequest(method, url.String(), bytes.NewReader(document.Body))
+ req, err := c.createRequest(method, url.String(), document.Body)
if err != nil {
return resultWithErr(result, err)
}
@@ -166,7 +207,7 @@ func (c *Client) Send(document Document) Result {
}
defer resp.Body.Close()
elapsed := c.now().Sub(start)
- return c.resultWithResponse(resp, result, document, elapsed)
+ return resultWithResponse(resp, result, document, elapsed)
}
func resultWithErr(result Result, err error) Result {
@@ -176,7 +217,7 @@ func resultWithErr(result Result, err error) Result {
return result
}
-func (c *Client) resultWithResponse(resp *http.Response, result Result, document Document, elapsed time.Duration) Result {
+func resultWithResponse(resp *http.Response, result Result, document Document, elapsed time.Duration) Result {
result.HTTPStatus = resp.StatusCode
result.Stats.Responses++
result.Stats.ResponsesByCode = map[int]int64{resp.StatusCode: 1}
diff --git a/client/go/internal/vespa/document/http_test.go b/client/go/internal/vespa/document/http_test.go
index 43eaf1bfdf9..314113c53be 100644
--- a/client/go/internal/vespa/document/http_test.go
+++ b/client/go/internal/vespa/document/http_test.go
@@ -7,6 +7,7 @@ import (
"net/http"
"net/url"
"reflect"
+ "strings"
"testing"
"time"
@@ -108,7 +109,7 @@ func TestClientSend(t *testing.T) {
if r.Method != http.MethodPut {
t.Errorf("got r.Method = %q, want %q", r.Method, http.MethodPut)
}
- wantURL := fmt.Sprintf("https://example.com:1337/document/v1/ns/type/docid/%s?create=true&timeout=5000ms", doc.Id.UserSpecific)
+ wantURL := fmt.Sprintf("https://example.com:1337/document/v1/ns/type/docid/%s?create=true&timeout=5500ms", doc.Id.UserSpecific)
if r.URL.String() != wantURL {
t.Errorf("got r.URL = %q, want %q", r.URL, wantURL)
}
@@ -141,6 +142,55 @@ func TestClientSend(t *testing.T) {
}
}
+func TestClientSendCompressed(t *testing.T) {
+ httpClient := mock.HTTPClient{}
+ client := NewClient(ClientOptions{
+ BaseURL: "https://example.com:1337",
+ Timeout: time.Duration(5 * time.Second),
+ }, []util.HTTPClient{&httpClient})
+
+ bigBody := fmt.Sprintf(`{"fields":{"foo": "%s"}}`, strings.Repeat("s", 512+1))
+ bigDoc := Document{Create: true, Id: mustParseId("id:ns:type::doc1"), Operation: OperationUpdate, Body: []byte(bigBody)}
+ smallDoc := Document{Create: true, Id: mustParseId("id:ns:type::doc2"), Operation: OperationUpdate, Body: []byte(`{"fields":{"foo": "s"}}`)}
+
+ client.options.Compression = CompressionNone
+ _ = client.Send(bigDoc)
+ assertCompressedRequest(t, false, httpClient.LastRequest)
+ _ = client.Send(smallDoc)
+ assertCompressedRequest(t, false, httpClient.LastRequest)
+
+ client.options.Compression = CompressionAuto
+ _ = client.Send(bigDoc)
+ assertCompressedRequest(t, true, httpClient.LastRequest)
+ _ = client.Send(smallDoc)
+ assertCompressedRequest(t, false, httpClient.LastRequest)
+
+ client.options.Compression = CompressionGzip
+ _ = client.Send(bigDoc)
+ assertCompressedRequest(t, true, httpClient.LastRequest)
+ _ = client.Send(smallDoc)
+ assertCompressedRequest(t, true, httpClient.LastRequest)
+}
+
+func assertCompressedRequest(t *testing.T, want bool, request *http.Request) {
+ wantEnc := ""
+ if want {
+ wantEnc = "gzip"
+ }
+ gotEnc := request.Header.Get("Content-Encoding")
+ if gotEnc != wantEnc {
+ t.Errorf("got Content-Encoding=%q, want %q", gotEnc, wantEnc)
+ }
+ body, err := io.ReadAll(request.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ compressed := bytes.HasPrefix(body, []byte{0x1f, 0x8b})
+ if compressed != want {
+ t.Errorf("got compressed=%t, want %t", compressed, want)
+ }
+}
+
func TestURLPath(t *testing.T) {
tests := []struct {
in Id
diff --git a/client/go/internal/vespa/document/queue.go b/client/go/internal/vespa/document/queue.go
new file mode 100644
index 00000000000..2e5a1976d58
--- /dev/null
+++ b/client/go/internal/vespa/document/queue.go
@@ -0,0 +1,43 @@
+package document
+
+import (
+ "container/list"
+ "sync"
+)
+
+// Queue wraps a doubly linked list. It attempts to re-use lists through a sync.Pool to reduce GC pressure.
+type Queue[T any] struct {
+ items *list.List
+ listPool *sync.Pool
+}
+
+func NewQueue[T any](listPool *sync.Pool) *Queue[T] {
+ if listPool.New == nil {
+ listPool.New = func() any { return list.New() }
+ }
+ return &Queue[T]{listPool: listPool}
+}
+
+func (q *Queue[T]) Add(item T, front bool) {
+ if q.items == nil {
+ q.items = q.listPool.Get().(*list.List)
+ }
+ if front {
+ q.items.PushFront(item)
+ } else {
+ q.items.PushBack(item)
+ }
+}
+
+func (q *Queue[T]) Poll() (T, bool) {
+ if q.items == nil || q.items.Front() == nil {
+ var empty T
+ return empty, false
+ }
+ item := q.items.Remove(q.items.Front()).(T)
+ if q.items.Front() == nil { // Emptied queue, release list back to pool
+ q.listPool.Put(q.items)
+ q.items = nil
+ }
+ return item, true
+}
diff --git a/client/go/internal/vespa/document/queue_test.go b/client/go/internal/vespa/document/queue_test.go
new file mode 100644
index 00000000000..992e7410053
--- /dev/null
+++ b/client/go/internal/vespa/document/queue_test.go
@@ -0,0 +1,29 @@
+package document
+
+import (
+ "sync"
+ "testing"
+)
+
+func TestQueue(t *testing.T) {
+ q := NewQueue[int](&sync.Pool{})
+ assertPoll(t, q, 0, false)
+ q.Add(1, false)
+ q.Add(2, false)
+ assertPoll(t, q, 1, true)
+ assertPoll(t, q, 2, true)
+ q.Add(3, false)
+ q.Add(4, true)
+ assertPoll(t, q, 4, true)
+ assertPoll(t, q, 3, true)
+}
+
+func assertPoll(t *testing.T, q *Queue[int], want int, wantOk bool) {
+ got, ok := q.Poll()
+ if ok != wantOk {
+ t.Fatalf("got ok=%t, want %t", ok, wantOk)
+ }
+ if got != want {
+ t.Fatalf("got v=%d, want %d", got, want)
+ }
+}
diff --git a/client/go/internal/vespa/target.go b/client/go/internal/vespa/target.go
index bc936623bcb..9f3fd7f5c65 100644
--- a/client/go/internal/vespa/target.go
+++ b/client/go/internal/vespa/target.go
@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
+ "sync"
"time"
"github.com/vespa-engine/vespa/client/go/internal/util"
@@ -17,7 +18,7 @@ const (
// A target for a local Vespa service
TargetLocal = "local"
- // A target for a custom URL
+ // A target for a Vespa service at a custom URL
TargetCustom = "custom"
// A Vespa Cloud target
@@ -38,13 +39,19 @@ const (
retryInterval = 2 * time.Second
)
+// Authenticator authenticates the given HTTP request.
+type Authenticator interface {
+ Authenticate(request *http.Request) error
+}
+
// Service represents a Vespa service.
type Service struct {
BaseURL string
Name string
TLSOptions TLSOptions
- zts zts
+ once sync.Once
+ auth Authenticator
httpClient util.HTTPClient
}
@@ -65,19 +72,19 @@ type Target interface {
// PrintLog writes the logs of this deployment using given options to control output.
PrintLog(options LogOptions) error
- // SignRequest signs request with given keyID as required by the implementation of this target.
- SignRequest(request *http.Request, keyID string) error
-
// CheckVersion verifies whether clientVersion is compatible with this target.
CheckVersion(clientVersion version.Version) error
}
-// TLSOptions configures the client certificate to use for cloud API or service requests.
+// TLSOptions holds the client certificate to use for cloud API or service requests.
type TLSOptions struct {
- KeyPair []tls.Certificate
- CertificateFile string
- PrivateKeyFile string
- AthenzDomain string
+ CACertificate []byte
+ KeyPair []tls.Certificate
+ TrustAll bool
+
+ CACertificateFile string
+ CertificateFile string
+ PrivateKeyFile string
}
// LogOptions configures the log output to produce when writing log messages.
@@ -90,17 +97,15 @@ type LogOptions struct {
Level int
}
-// Do sends request to this service. Any required authentication happens automatically.
+// Do sends request to this service. Authentication of the request happens automatically.
func (s *Service) Do(request *http.Request, timeout time.Duration) (*http.Response, error) {
- if s.TLSOptions.AthenzDomain != "" && s.TLSOptions.KeyPair != nil {
- accessToken, err := s.zts.AccessToken(s.TLSOptions.AthenzDomain, s.TLSOptions.KeyPair[0])
- if err != nil {
+ s.once.Do(func() {
+ util.ConfigureTLS(s.httpClient, s.TLSOptions.KeyPair, s.TLSOptions.CACertificate, s.TLSOptions.TrustAll)
+ })
+ if s.auth != nil {
+ if err := s.auth.Authenticate(request); err != nil {
return nil, err
}
- if request.Header == nil {
- request.Header = make(http.Header)
- }
- request.Header.Add("Authorization", "Bearer "+accessToken)
}
return s.httpClient.Do(request, timeout)
}
@@ -118,7 +123,7 @@ func (s *Service) Wait(timeout time.Duration) (int, error) {
default:
return 0, fmt.Errorf("invalid service: %s", s.Name)
}
- return waitForOK(s.httpClient, url, s.TLSOptions.KeyPair, timeout)
+ return waitForOK(s, url, timeout)
}
func (s *Service) Description() string {
@@ -133,27 +138,40 @@ func (s *Service) Description() string {
return fmt.Sprintf("No description of service %s", s.Name)
}
-func isOK(status int) bool { return status/100 == 2 }
+func isOK(status int) (bool, error) {
+ class := status / 100
+ switch class {
+ case 2: // success
+ return true, nil
+ case 4: // client error
+ return false, fmt.Errorf("request failed with status %d", status)
+ default: // retry
+ return false, nil
+ }
+}
type responseFunc func(status int, response []byte) (bool, error)
type requestFunc func() *http.Request
-// waitForOK queries url and returns its status code. If the url returns a non-200 status code, it is repeatedly queried
+// waitForOK queries url and returns its status code. If response status is not 2xx or 4xx, it is repeatedly queried
// until timeout elapses.
-func waitForOK(client util.HTTPClient, url string, certificates []tls.Certificate, timeout time.Duration) (int, error) {
+func waitForOK(service *Service, url string, timeout time.Duration) (int, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return 0, err
}
- okFunc := func(status int, response []byte) (bool, error) { return isOK(status), nil }
- return wait(client, okFunc, func() *http.Request { return req }, certificates, timeout)
+ okFunc := func(status int, response []byte) (bool, error) {
+ ok, err := isOK(status)
+ if err != nil {
+ return false, fmt.Errorf("failed to query %s at %s: %w", service.Description(), url, err)
+ }
+ return ok, err
+ }
+ return wait(service, okFunc, func() *http.Request { return req }, timeout)
}
-func wait(client util.HTTPClient, fn responseFunc, reqFn requestFunc, certificates []tls.Certificate, timeout time.Duration) (int, error) {
- if certificates != nil {
- util.SetCertificates(client, certificates)
- }
+func wait(service *Service, fn responseFunc, reqFn requestFunc, timeout time.Duration) (int, error) {
var (
httpErr error
response *http.Response
@@ -163,7 +181,7 @@ func wait(client util.HTTPClient, fn responseFunc, reqFn requestFunc, certificat
loopOnce := timeout == 0
for time.Now().Before(deadline) || loopOnce {
req := reqFn()
- response, httpErr = client.Do(req, 10*time.Second)
+ response, httpErr = service.Do(req, 10*time.Second)
if httpErr == nil {
statusCode = response.StatusCode
body, err := io.ReadAll(response.Body)
diff --git a/client/go/internal/vespa/target_cloud.go b/client/go/internal/vespa/target_cloud.go
index 1fb3edd78c5..928bb788494 100644
--- a/client/go/internal/vespa/target_cloud.go
+++ b/client/go/internal/vespa/target_cloud.go
@@ -2,7 +2,6 @@ package vespa
import (
"bytes"
- "crypto/tls"
"encoding/json"
"fmt"
"math"
@@ -35,8 +34,8 @@ type cloudTarget struct {
deploymentOptions CloudDeploymentOptions
logOptions LogOptions
httpClient util.HTTPClient
- zts zts
- auth0 auth0
+ apiAuth Authenticator
+ deploymentAuth Authenticator
}
type deploymentEndpoint struct {
@@ -62,23 +61,15 @@ type logMessage struct {
Message string `json:"message"`
}
-type zts interface {
- AccessToken(domain string, certficiate tls.Certificate) (string, error)
-}
-
-type auth0 interface {
- AccessToken() (string, error)
-}
-
// CloudTarget creates a Target for the Vespa Cloud or hosted Vespa platform.
-func CloudTarget(httpClient util.HTTPClient, ztsClient zts, auth0Client auth0, apiOptions APIOptions, deploymentOptions CloudDeploymentOptions, logOptions LogOptions) (Target, error) {
+func CloudTarget(httpClient util.HTTPClient, apiAuth Authenticator, deploymentAuth Authenticator, apiOptions APIOptions, deploymentOptions CloudDeploymentOptions, logOptions LogOptions) (Target, error) {
return &cloudTarget{
httpClient: httpClient,
apiOptions: apiOptions,
deploymentOptions: deploymentOptions,
logOptions: logOptions,
- zts: ztsClient,
- auth0: auth0Client,
+ apiAuth: apiAuth,
+ deploymentAuth: deploymentAuth,
}, nil
}
@@ -118,25 +109,25 @@ func (t *cloudTarget) IsCloud() bool { return true }
func (t *cloudTarget) Deployment() Deployment { return t.deploymentOptions.Deployment }
func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, cluster string) (*Service, error) {
- var service *Service
switch name {
case DeployService:
- service = &Service{
+ service := &Service{
Name: name,
BaseURL: t.apiOptions.System.URL,
TLSOptions: t.apiOptions.TLSOptions,
- zts: t.zts,
httpClient: t.httpClient,
+ auth: t.apiAuth,
}
if timeout > 0 {
status, err := service.Wait(timeout)
if err != nil {
return nil, err
}
- if !isOK(status) {
+ if ok, _ := isOK(status); !ok {
return nil, fmt.Errorf("got status %d from deploy service at %s", status, service.BaseURL)
}
}
+ return service, nil
case QueryService, DocumentService:
if t.deploymentOptions.ClusterURLs == nil {
if err := t.waitForEndpoints(timeout, runID); err != nil {
@@ -147,38 +138,15 @@ func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, c
if err != nil {
return nil, err
}
- t.deploymentOptions.TLSOptions.AthenzDomain = t.apiOptions.System.AthenzDomain
- service = &Service{
+ return &Service{
Name: name,
BaseURL: url,
TLSOptions: t.deploymentOptions.TLSOptions,
- zts: t.zts,
httpClient: t.httpClient,
- }
-
+ auth: t.deploymentAuth,
+ }, nil
default:
return nil, fmt.Errorf("unknown service: %s", name)
-
- }
- if service.TLSOptions.KeyPair != nil {
- util.SetCertificates(service.httpClient, service.TLSOptions.KeyPair)
- }
- return service, nil
-}
-
-func (t *cloudTarget) SignRequest(req *http.Request, keyID string) error {
- if t.apiOptions.System.IsPublic() {
- if t.apiOptions.APIKey != nil {
- signer := NewRequestSigner(keyID, t.apiOptions.APIKey)
- return signer.SignRequest(req)
- } else {
- return t.addAuth0AccessToken(req)
- }
- } else {
- if t.apiOptions.TLSOptions.KeyPair == nil {
- return fmt.Errorf("system %s requires a certificate for authentication", t.apiOptions.System.Name)
- }
- return nil
}
}
@@ -190,7 +158,11 @@ func (t *cloudTarget) CheckVersion(clientVersion version.Version) error {
if err != nil {
return err
}
- response, err := t.httpClient.Do(req, 10*time.Second)
+ deployService, err := t.Service(DeployService, 0, 0, "")
+ if err != nil {
+ return err
+ }
+ response, err := deployService.Do(req, 10*time.Second)
if err != nil {
return err
}
@@ -212,18 +184,6 @@ func (t *cloudTarget) CheckVersion(clientVersion version.Version) error {
return nil
}
-func (t *cloudTarget) addAuth0AccessToken(request *http.Request) error {
- accessToken, err := t.auth0.AccessToken()
- if err != nil {
- return err
- }
- if request.Header == nil {
- request.Header = make(http.Header)
- }
- request.Header.Set("Authorization", "Bearer "+accessToken)
- return nil
-}
-
func (t *cloudTarget) logsURL() string {
return fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s/logs",
t.apiOptions.System.URL,
@@ -246,11 +206,10 @@ func (t *cloudTarget) PrintLog(options LogOptions) error {
q.Set("to", strconv.FormatInt(toMillis, 10))
}
req.URL.RawQuery = q.Encode()
- t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm())
return req
}
logFunc := func(status int, response []byte) (bool, error) {
- if ok, err := isCloudOK(status); !ok {
+ if ok, err := isOK(status); !ok {
return ok, err
}
logEntries, err := ReadLogEntries(bytes.NewReader(response))
@@ -275,10 +234,18 @@ func (t *cloudTarget) PrintLog(options LogOptions) error {
if options.Follow {
timeout = math.MaxInt64 // No timeout
}
- _, err = wait(t.httpClient, logFunc, requestFunc, t.apiOptions.TLSOptions.KeyPair, timeout)
+ _, err = t.deployServiceWait(logFunc, requestFunc, timeout)
return err
}
+func (t *cloudTarget) deployServiceWait(fn responseFunc, reqFn requestFunc, timeout time.Duration) (int, error) {
+ deployService, err := t.Service(DeployService, 0, 0, "")
+ if err != nil {
+ return 0, err
+ }
+ return wait(deployService, fn, reqFn, timeout)
+}
+
func (t *cloudTarget) waitForEndpoints(timeout time.Duration, runID int64) error {
if runID > 0 {
if err := t.waitForRun(runID, timeout); err != nil {
@@ -302,13 +269,10 @@ func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error {
q := req.URL.Query()
q.Set("after", strconv.FormatInt(lastID, 10))
req.URL.RawQuery = q.Encode()
- if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil {
- util.JustExitWith(err)
- }
return req
}
jobSuccessFunc := func(status int, response []byte) (bool, error) {
- if ok, err := isCloudOK(status); !ok {
+ if ok, err := isOK(status); !ok {
return ok, err
}
var resp jobResponse
@@ -326,7 +290,7 @@ func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error {
}
return true, nil
}
- _, err = wait(t.httpClient, jobSuccessFunc, requestFunc, t.apiOptions.TLSOptions.KeyPair, timeout)
+ _, err = t.deployServiceWait(jobSuccessFunc, requestFunc, timeout)
return err
}
@@ -361,12 +325,9 @@ func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error {
if err != nil {
return err
}
- if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil {
- return err
- }
urlsByCluster := make(map[string]string)
endpointFunc := func(status int, response []byte) (bool, error) {
- if ok, err := isCloudOK(status); !ok {
+ if ok, err := isOK(status); !ok {
return ok, err
}
var resp deploymentResponse
@@ -384,7 +345,7 @@ func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error {
}
return true, nil
}
- if _, err = wait(t.httpClient, endpointFunc, func() *http.Request { return req }, t.apiOptions.TLSOptions.KeyPair, timeout); err != nil {
+ if _, err := t.deployServiceWait(endpointFunc, func() *http.Request { return req }, timeout); err != nil {
return err
}
if len(urlsByCluster) == 0 {
@@ -393,11 +354,3 @@ func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error {
t.deploymentOptions.ClusterURLs = urlsByCluster
return nil
}
-
-func isCloudOK(status int) (bool, error) {
- if status == 401 {
- // when retrying we should give up immediately if we're not authorized
- return false, fmt.Errorf("status %d: invalid credentials", status)
- }
- return isOK(status), nil
-}
diff --git a/client/go/internal/vespa/target_custom.go b/client/go/internal/vespa/target_custom.go
index 848d19f0a90..0a3a9d48fed 100644
--- a/client/go/internal/vespa/target_custom.go
+++ b/client/go/internal/vespa/target_custom.go
@@ -15,6 +15,7 @@ type customTarget struct {
targetType string
baseURL string
httpClient util.HTTPClient
+ tlsOptions TLSOptions
}
type serviceConvergeResponse struct {
@@ -22,13 +23,13 @@ type serviceConvergeResponse struct {
}
// LocalTarget creates a target for a Vespa platform running locally.
-func LocalTarget(httpClient util.HTTPClient) Target {
- return &customTarget{targetType: TargetLocal, baseURL: "http://127.0.0.1", httpClient: httpClient}
+func LocalTarget(httpClient util.HTTPClient, tlsOptions TLSOptions) Target {
+ return &customTarget{targetType: TargetLocal, baseURL: "http://127.0.0.1", httpClient: httpClient, tlsOptions: tlsOptions}
}
// CustomTarget creates a Target for a Vespa platform running at baseURL.
-func CustomTarget(httpClient util.HTTPClient, baseURL string) Target {
- return &customTarget{targetType: TargetCustom, baseURL: baseURL, httpClient: httpClient}
+func CustomTarget(httpClient util.HTTPClient, baseURL string, tlsOptions TLSOptions) Target {
+ return &customTarget{targetType: TargetCustom, baseURL: baseURL, httpClient: httpClient, tlsOptions: tlsOptions}
}
func (t *customTarget) Type() string { return t.targetType }
@@ -44,7 +45,7 @@ func (t *customTarget) createService(name string) (*Service, error) {
if err != nil {
return nil, err
}
- return &Service{BaseURL: url, Name: name, httpClient: t.httpClient}, nil
+ return &Service{BaseURL: url, Name: name, httpClient: t.httpClient, TLSOptions: t.tlsOptions}, nil
}
return nil, fmt.Errorf("unknown service: %s", name)
}
@@ -60,7 +61,7 @@ func (t *customTarget) Service(name string, timeout time.Duration, sessionOrRunI
if err != nil {
return nil, err
}
- if !isOK(status) {
+ if ok, _ := isOK(status); !ok {
return nil, fmt.Errorf("got status %d from deploy service at %s", status, service.BaseURL)
}
} else {
@@ -76,8 +77,6 @@ func (t *customTarget) PrintLog(options LogOptions) error {
return fmt.Errorf("log access is only supported on cloud: run vespa-logfmt on the admin node instead")
}
-func (t *customTarget) SignRequest(req *http.Request, sigKeyId string) error { return nil }
-
func (t *customTarget) CheckVersion(version version.Version) error { return nil }
func (t *customTarget) urlWithPort(serviceName string) (string, error) {
@@ -101,19 +100,19 @@ func (t *customTarget) urlWithPort(serviceName string) (string, error) {
}
func (t *customTarget) waitForConvergence(timeout time.Duration) error {
- deployURL, err := t.urlWithPort(DeployService)
+ deployService, err := t.createService(DeployService)
if err != nil {
return err
}
- url := fmt.Sprintf("%s/application/v2/tenant/default/application/default/environment/prod/region/default/instance/default/serviceconverge", deployURL)
+ url := fmt.Sprintf("%s/application/v2/tenant/default/application/default/environment/prod/region/default/instance/default/serviceconverge", deployService.BaseURL)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return err
}
converged := false
convergedFunc := func(status int, response []byte) (bool, error) {
- if !isOK(status) {
- return false, nil
+ if ok, err := isOK(status); !ok {
+ return ok, err
}
var resp serviceConvergeResponse
if err := json.Unmarshal(response, &resp); err != nil {
@@ -122,7 +121,7 @@ func (t *customTarget) waitForConvergence(timeout time.Duration) error {
converged = resp.Converged
return converged, nil
}
- if _, err := wait(t.httpClient, convergedFunc, func() *http.Request { return req }, nil, timeout); err != nil {
+ if _, err := wait(deployService, convergedFunc, func() *http.Request { return req }, timeout); err != nil {
return err
}
if !converged {
diff --git a/client/go/internal/vespa/target_test.go b/client/go/internal/vespa/target_test.go
index b9d65f3d8a4..bf266e8f9ec 100644
--- a/client/go/internal/vespa/target_test.go
+++ b/client/go/internal/vespa/target_test.go
@@ -3,7 +3,6 @@ package vespa
import (
"bytes"
- "crypto/tls"
"fmt"
"io"
"net/http"
@@ -19,10 +18,16 @@ import (
type mockVespaApi struct {
deploymentConverged bool
+ authFailure bool
serverURL string
}
func (v *mockVespaApi) mockVespaHandler(w http.ResponseWriter, req *http.Request) {
+ if v.authFailure {
+ response := `{"message":"unauthorized"}`
+ w.WriteHeader(401)
+ w.Write([]byte(response))
+ }
switch req.URL.Path {
case "/cli/v1/":
response := `{"minVersion":"8.0.0"}`
@@ -65,17 +70,17 @@ func (v *mockVespaApi) mockVespaHandler(w http.ResponseWriter, req *http.Request
}
func TestCustomTarget(t *testing.T) {
- lt := LocalTarget(&mock.HTTPClient{})
+ lt := LocalTarget(&mock.HTTPClient{}, TLSOptions{})
assertServiceURL(t, "http://127.0.0.1:19071", lt, "deploy")
assertServiceURL(t, "http://127.0.0.1:8080", lt, "query")
assertServiceURL(t, "http://127.0.0.1:8080", lt, "document")
- ct := CustomTarget(&mock.HTTPClient{}, "http://192.0.2.42")
+ ct := CustomTarget(&mock.HTTPClient{}, "http://192.0.2.42", TLSOptions{})
assertServiceURL(t, "http://192.0.2.42:19071", ct, "deploy")
assertServiceURL(t, "http://192.0.2.42:8080", ct, "query")
assertServiceURL(t, "http://192.0.2.42:8080", ct, "document")
- ct2 := CustomTarget(&mock.HTTPClient{}, "http://192.0.2.42:60000")
+ ct2 := CustomTarget(&mock.HTTPClient{}, "http://192.0.2.42:60000", TLSOptions{})
assertServiceURL(t, "http://192.0.2.42:60000", ct2, "deploy")
assertServiceURL(t, "http://192.0.2.42:60000", ct2, "query")
assertServiceURL(t, "http://192.0.2.42:60000", ct2, "document")
@@ -85,7 +90,7 @@ func TestCustomTargetWait(t *testing.T) {
vc := mockVespaApi{}
srv := httptest.NewServer(http.HandlerFunc(vc.mockVespaHandler))
defer srv.Close()
- target := CustomTarget(util.CreateClient(time.Second*10), srv.URL)
+ target := CustomTarget(util.CreateClient(time.Second*10), srv.URL, TLSOptions{})
_, err := target.Service("query", time.Millisecond, 42, "")
assert.NotNil(t, err)
@@ -107,6 +112,9 @@ func TestCloudTargetWait(t *testing.T) {
var logWriter bytes.Buffer
target := createCloudTarget(t, srv.URL, &logWriter)
+ vc.authFailure = true
+ assertServiceWaitErr(t, 401, true, target, "deploy")
+ vc.authFailure = false
assertServiceWait(t, 200, target, "deploy")
_, err := target.Service("query", time.Millisecond, 42, "")
@@ -157,10 +165,11 @@ func createCloudTarget(t *testing.T, url string, logWriter io.Writer) Target {
apiKey, err := CreateAPIKey()
assert.Nil(t, err)
+ auth := &mockAuthenticator{}
target, err := CloudTarget(
util.CreateClient(time.Second*10),
- &mockZTS{},
- &mockAuth0{},
+ auth,
+ auth,
APIOptions{APIKey: apiKey, System: PublicSystem},
CloudDeploymentOptions{
Deployment: Deployment{
@@ -175,7 +184,6 @@ func createCloudTarget(t *testing.T, url string, logWriter io.Writer) Target {
}
if ct, ok := target.(*cloudTarget); ok {
ct.apiOptions.System.URL = url
- ct.zts = &mockZTS{token: "foo bar"}
} else {
t.Fatalf("Wrong target type %T", ct)
}
@@ -189,22 +197,22 @@ func assertServiceURL(t *testing.T, url string, target Target, service string) {
}
func assertServiceWait(t *testing.T, expectedStatus int, target Target, service string) {
+ assertServiceWaitErr(t, expectedStatus, false, target, service)
+}
+
+func assertServiceWaitErr(t *testing.T, expectedStatus int, expectErr bool, target Target, service string) {
s, err := target.Service(service, 0, 42, "")
assert.Nil(t, err)
status, err := s.Wait(0)
- assert.Nil(t, err)
+ if expectErr {
+ assert.NotNil(t, err)
+ } else {
+ assert.Nil(t, err)
+ }
assert.Equal(t, expectedStatus, status)
}
-type mockZTS struct{ token string }
-
-func (c *mockZTS) AccessToken(domain string, certificate tls.Certificate) (string, error) {
- return c.token, nil
-}
-
-type mockAuth0 struct{}
-
-func (a *mockAuth0) AccessToken() (string, error) { return "", nil }
+type mockAuthenticator struct{}
-func (a *mockAuth0) HasCredentials() bool { return true }
+func (a *mockAuthenticator) Authenticate(request *http.Request) error { return nil }
diff --git a/client/pom.xml b/client/pom.xml
index a310e7d6feb..6da6dc74a82 100644
--- a/client/pom.xml
+++ b/client/pom.xml
@@ -35,6 +35,12 @@
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<release>${vespaClients.jdk.releaseVersion}</release>
+ <compilerArgs> <!-- Remove (to use default) when not compiling for 8 -->
+ <arg>-Xlint:all</arg>
+ <arg>-Xlint:-rawtypes</arg>
+ <arg>-Xlint:-unchecked</arg>
+ <arg>-Xlint:-serial</arg>
+ </compilerArgs>
</configuration>
</plugin>
<plugin>
diff --git a/cloud-tenant-base-dependencies-enforcer/pom.xml b/cloud-tenant-base-dependencies-enforcer/pom.xml
index 4d5d801e0e3..69e1a94a813 100644
--- a/cloud-tenant-base-dependencies-enforcer/pom.xml
+++ b/cloud-tenant-base-dependencies-enforcer/pom.xml
@@ -30,11 +30,12 @@
<httpcore.version>4.4.16</httpcore.version>
<junit5.version>5.8.1</junit5.version> <!-- TODO: in parent this is named 'junit.version' -->
<onnxruntime.version>1.13.1</onnxruntime.version>
+ <openai-gpt3.version>0.12.0</openai-gpt3.version>
<!-- END parent/pom.xml -->
<!-- ALL BELOW MUST BE KEPT IN SYNC WITH container-dependency-versions pom
- Copied here because vz-tenant-base does not have a parent. -->
+ Copied here because cloud-tenant-base does not have a parent. -->
<aopalliance.version>1.0</aopalliance.version>
<guava.version>27.1-jre</guava.version>
<guice.version>4.2.3</guice.version>
@@ -234,6 +235,18 @@
<include>org.osgi:org.osgi.compendium:4.1.0:test</include>
<include>org.osgi:org.osgi.core:4.1.0:test</include>
<include>xerces:xercesImpl:2.12.2:test</include>
+
+ <include>com.squareup.okhttp3:okhttp:3.14.9:test</include>
+ <include>com.squareup.okio:okio:1.17.2:test</include>
+ <include>com.squareup.retrofit2:adapter-rxjava2:2.9.0:test</include>
+ <include>com.squareup.retrofit2:converter-jackson:2.9.0:test</include>
+ <include>com.squareup.retrofit2:retrofit:2.9.0:test</include>
+ <include>com.theokanning.openai-gpt3-java:api:${openai-gpt3.version}:test</include>
+ <include>com.theokanning.openai-gpt3-java:client:${openai-gpt3.version}:test</include>
+ <include>com.theokanning.openai-gpt3-java:service:${openai-gpt3.version}:test</include>
+ <include>io.reactivex.rxjava2:rxjava:2.0.0:test</include>
+ <include>org.reactivestreams:reactive-streams:1.0.3:test</include>
+
</allowed>
</enforceDependencies>
</rules>
diff --git a/clustercontroller-core/pom.xml b/clustercontroller-core/pom.xml
index b4ac5ca869c..647d8ca4e64 100644
--- a/clustercontroller-core/pom.xml
+++ b/clustercontroller-core/pom.xml
@@ -64,6 +64,11 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.junit.jupiter</groupId>
+ <artifactId>junit-jupiter-params</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<scope>provided</scope>
diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/ContentCluster.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/ContentCluster.java
index b83d70b8656..2535589395d 100644
--- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/ContentCluster.java
+++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/ContentCluster.java
@@ -32,7 +32,7 @@ public class ContentCluster {
private final int maxNumberOfGroupsAllowedToBeDown;
public ContentCluster(String clusterName, Collection<ConfiguredNode> configuredNodes, Distribution distribution) {
- this(clusterName, configuredNodes, distribution, 1);
+ this(clusterName, configuredNodes, distribution, -1);
}
public ContentCluster(FleetControllerOptions options) {
@@ -40,9 +40,9 @@ public class ContentCluster {
}
ContentCluster(String clusterName,
- Collection<ConfiguredNode> configuredNodes,
- Distribution distribution,
- int maxNumberOfGroupsAllowedToBeDown) {
+ Collection<ConfiguredNode> configuredNodes,
+ Distribution distribution,
+ int maxNumberOfGroupsAllowedToBeDown) {
if (configuredNodes == null) throw new IllegalArgumentException("Nodes must be set");
this.clusterName = clusterName;
this.distribution = distribution;
diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java
index e242833fd0c..c823c94afd1 100644
--- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java
+++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java
@@ -13,13 +13,20 @@ import com.yahoo.vespa.clustercontroller.core.hostinfo.HostInfo;
import com.yahoo.vespa.clustercontroller.core.hostinfo.Metrics;
import com.yahoo.vespa.clustercontroller.core.hostinfo.StorageNode;
import com.yahoo.vespa.clustercontroller.utils.staterestapi.requests.SetUnitStateRequest;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
+import java.util.Set;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
import static com.yahoo.vdslib.state.NodeType.STORAGE;
import static com.yahoo.vdslib.state.State.DOWN;
+import static com.yahoo.vdslib.state.State.MAINTENANCE;
import static com.yahoo.vdslib.state.State.RETIRED;
import static com.yahoo.vdslib.state.State.UP;
import static com.yahoo.vespa.clustercontroller.core.NodeStateChangeChecker.Result.allowSettingOfWantedState;
@@ -27,6 +34,7 @@ import static com.yahoo.vespa.clustercontroller.core.NodeStateChangeChecker.Resu
import static com.yahoo.vespa.clustercontroller.core.NodeStateChangeChecker.Result.createDisallowed;
import static com.yahoo.vespa.clustercontroller.utils.staterestapi.requests.SetUnitStateRequest.Condition.FORCE;
import static com.yahoo.vespa.clustercontroller.utils.staterestapi.requests.SetUnitStateRequest.Condition.SAFE;
+import static java.util.logging.Level.FINE;
/**
* Checks if a node can be upgraded.
@@ -35,8 +43,9 @@ import static com.yahoo.vespa.clustercontroller.utils.staterestapi.requests.SetU
*/
public class NodeStateChangeChecker {
- public static final String BUCKETS_METRIC_NAME = "vds.datastored.bucket_space.buckets_total";
- public static final Map<String, String> BUCKETS_METRIC_DIMENSIONS = Map.of("bucketSpace", "default");
+ private static final Logger log = Logger.getLogger(NodeStateChangeChecker.class.getName());
+ private static final String BUCKETS_METRIC_NAME = "vds.datastored.bucket_space.buckets_total";
+ private static final Map<String, String> BUCKETS_METRIC_DIMENSIONS = Map.of("bucketSpace", "default");
private final int requiredRedundancy;
private final HierarchicalGroupVisiting groupVisiting;
@@ -50,6 +59,8 @@ public class NodeStateChangeChecker {
this.clusterInfo = cluster.clusterInfo();
this.inMoratorium = inMoratorium;
this.maxNumberOfGroupsAllowedToBeDown = cluster.maxNumberOfGroupsAllowedToBeDown();
+ if ( ! groupVisiting.isHierarchical() && maxNumberOfGroupsAllowedToBeDown > 1)
+ throw new IllegalArgumentException("Cannot have both 1 group and maxNumberOfGroupsAllowedToBeDown > 1");
}
public static class Result {
@@ -214,26 +225,34 @@ public class NodeStateChangeChecker {
oldWantedState.getState() + ": " + oldWantedState.getDescription());
}
- Result otherGroupCheck = anotherNodeInAnotherGroupHasWantedState(nodeInfo);
- if (!otherGroupCheck.settingWantedStateIsAllowed()) {
- return otherGroupCheck;
+ if (maxNumberOfGroupsAllowedToBeDown == -1) {
+ var otherGroupCheck = anotherNodeInAnotherGroupHasWantedState(nodeInfo);
+ if (!otherGroupCheck.settingWantedStateIsAllowed()) {
+ return otherGroupCheck;
+ }
+ if (anotherNodeInGroupAlreadyAllowed(nodeInfo, newDescription)) {
+ return allowSettingOfWantedState();
+ }
+ } else {
+ var result = otherNodesHaveWantedState(nodeInfo, newDescription, clusterState);
+ if (result.isPresent())
+ return result.get();
}
if (clusterState.getNodeState(nodeInfo.getNode()).getState() == DOWN) {
- return allowSettingOfWantedState();
- }
-
- if (anotherNodeInGroupAlreadyAllowed(nodeInfo, newDescription)) {
+ log.log(FINE, "node is DOWN, allow");
return allowSettingOfWantedState();
}
Result allNodesAreUpCheck = checkAllNodesAreUp(clusterState);
if (!allNodesAreUpCheck.settingWantedStateIsAllowed()) {
+ log.log(FINE, "allNodesAreUpCheck: " + allNodesAreUpCheck);
return allNodesAreUpCheck;
}
Result checkDistributorsResult = checkDistributors(nodeInfo.getNode(), clusterState.getVersion());
if (!checkDistributorsResult.settingWantedStateIsAllowed()) {
+ log.log(FINE, "checkDistributors: "+ checkDistributorsResult);
return checkDistributorsResult;
}
@@ -268,6 +287,65 @@ public class NodeStateChangeChecker {
}
}
+ /**
+ * Returns an optional Result, where return value is:
+ * For flat setup: Return Optional.of(disallowed) if wanted state is set on some node, else Optional.empty
+ * For hierarchical setup: No wanted state for other nodes, return Optional.empty
+ * Wanted state for nodes/groups are not UP:
+ * if less than maxNumberOfGroupsAllowedToBeDown: return Optional.of(allowed)
+ * else: if node is in group with nodes already down: return Optional.of(allowed), else Optional.of(disallowed)
+ */
+ private Optional<Result> otherNodesHaveWantedState(StorageNodeInfo nodeInfo, String newDescription, ClusterState clusterState) {
+ Node node = nodeInfo.getNode();
+
+ if (groupVisiting.isHierarchical()) {
+ Set<Integer> groupsWithNodesWantedStateNotUp = groupsWithUserWantedStateNotUp();
+ if (groupsWithNodesWantedStateNotUp.size() == 0) {
+ log.log(FINE, "groupsWithNodesWantedStateNotUp=0");
+ return Optional.empty();
+ }
+
+ Set<Integer> groupsWithSameStateAndDescription = groupsWithSameStateAndDescription(MAINTENANCE, newDescription);
+ if (aGroupContainsNode(groupsWithSameStateAndDescription, node)) {
+ log.log(FINE, "Node is in group with same state and description, allow");
+ return Optional.of(allowSettingOfWantedState());
+ }
+ // There are groups with nodes not up, but with another description, probably operator set
+ if (groupsWithSameStateAndDescription.size() == 0) {
+ return Optional.of(createDisallowed("Wanted state already set for another node in groups: " +
+ sortSetIntoList(groupsWithNodesWantedStateNotUp)));
+ }
+
+ Set<Integer> retiredAndNotUpGroups = groupsWithNotRetiredAndNotUp(clusterState);
+ int numberOfGroupsToConsider = retiredAndNotUpGroups.size();
+ // Subtract one group if node is in a group with nodes already retired or not up, since number of such groups will
+ // not increase if we allow node to go down
+ if (aGroupContainsNode(retiredAndNotUpGroups, node)) {
+ numberOfGroupsToConsider = retiredAndNotUpGroups.size() - 1;
+ }
+ if (numberOfGroupsToConsider < maxNumberOfGroupsAllowedToBeDown) {
+ log.log(FINE, "Allow, retiredAndNotUpGroups=" + retiredAndNotUpGroups);
+ return Optional.of(allowSettingOfWantedState());
+ }
+
+ return Optional.of(createDisallowed(String.format("At most %d groups can have wanted state: %s",
+ maxNumberOfGroupsAllowedToBeDown,
+ sortSetIntoList(retiredAndNotUpGroups))));
+ } else {
+ // Return a disallow-result if there is another node with a wanted state
+ var otherNodeHasWantedState = otherNodeHasWantedState(nodeInfo);
+ if ( ! otherNodeHasWantedState.settingWantedStateIsAllowed())
+ return Optional.of(otherNodeHasWantedState);
+ }
+ return Optional.empty();
+ }
+
+ private ArrayList<Integer> sortSetIntoList(Set<Integer> set) {
+ var sortedList = new ArrayList<>(set);
+ Collections.sort(sortedList);
+ return sortedList;
+ }
+
/** Returns a disallow-result, if there is a node in the group with wanted state != UP. */
private Result otherNodeInGroupHasWantedState(Group group) {
for (var configuredNode : group.getNodes()) {
@@ -354,6 +432,22 @@ public class NodeStateChangeChecker {
return false;
}
+ private boolean aGroupContainsNode(Collection<Integer> groupIndexes, Node node) {
+ for (Group group : getGroupsWithIndexes(groupIndexes)) {
+ if (groupContainsNode(group, node))
+ return true;
+ }
+
+ return false;
+ }
+
+ private List<Group> getGroupsWithIndexes(Collection<Integer> groupIndexes) {
+ return clusterInfo.getStorageNodeInfos().stream()
+ .map(NodeInfo::getGroup)
+ .filter(group -> groupIndexes.contains(group.getIndex()))
+ .collect(Collectors.toList());
+ }
+
private Result checkAllNodesAreUp(ClusterState clusterState) {
// This method verifies both storage nodes and distributors are up (or retired).
// The complicated part is making a summary error message.
@@ -441,4 +535,43 @@ public class NodeStateChangeChecker {
return allowSettingOfWantedState();
}
+ private Set<Integer> groupsWithUserWantedStateNotUp() {
+ return clusterInfo.getAllNodeInfos().stream()
+ .filter(sni -> !UP.equals(sni.getUserWantedState().getState()))
+ .map(NodeInfo::getGroup)
+ .filter(Objects::nonNull)
+ .filter(Group::isLeafGroup)
+ .map(Group::getIndex)
+ .collect(Collectors.toSet());
+ }
+
+ // groups with at least one node with the same state & description
+ private Set<Integer> groupsWithSameStateAndDescription(State state, String newDescription) {
+ return clusterInfo.getAllNodeInfos().stream()
+ .filter(nodeInfo -> {
+ var userWantedState = nodeInfo.getUserWantedState();
+ return userWantedState.getState() == state &&
+ Objects.equals(userWantedState.getDescription(), newDescription);
+ })
+ .map(NodeInfo::getGroup)
+ .filter(Objects::nonNull)
+ .filter(Group::isLeafGroup)
+ .map(Group::getIndex)
+ .collect(Collectors.toSet());
+ }
+
+ // groups with at least one node in state (not retired AND not up)
+ private Set<Integer> groupsWithNotRetiredAndNotUp(ClusterState clusterState) {
+ return clusterInfo.getAllNodeInfos().stream()
+ .filter(nodeInfo -> (nodeInfo.getUserWantedState().getState() != RETIRED
+ && nodeInfo.getUserWantedState().getState() != UP)
+ || (clusterState.getNodeState(nodeInfo.getNode()).getState() != RETIRED
+ && clusterState.getNodeState(nodeInfo.getNode()).getState() != UP))
+ .map(NodeInfo::getGroup)
+ .filter(Objects::nonNull)
+ .filter(Group::isLeafGroup)
+ .map(Group::getIndex)
+ .collect(Collectors.toSet());
+ }
+
}
diff --git a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeCheckerTest.java b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeCheckerTest.java
index 45ca07c88e4..c4fd7cb69b9 100644
--- a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeCheckerTest.java
+++ b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeCheckerTest.java
@@ -10,7 +10,8 @@ import com.yahoo.vdslib.state.State;
import com.yahoo.vespa.clustercontroller.core.hostinfo.HostInfo;
import com.yahoo.vespa.config.content.StorDistributionConfig;
import org.junit.jupiter.api.Test;
-import java.text.ParseException;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
@@ -45,13 +46,7 @@ public class NodeStateChangeCheckerTest {
return new NodeState(STORAGE, state).setDescription(description);
}
- private static ClusterState clusterState(String state) {
- try {
- return new ClusterState(state);
- } catch (ParseException e) {
- throw new RuntimeException(e);
- }
- }
+ private static ClusterState clusterState(String state) { return ClusterState.stateFromString(state); }
private static ClusterState defaultAllUpClusterState() {
return defaultAllUpClusterState(4);
@@ -68,14 +63,14 @@ public class NodeStateChangeCheckerTest {
return new NodeStateChangeChecker(cluster, false);
}
- private ContentCluster createCluster(int nodeCount) {
- return createCluster(nodeCount, 1);
+ private ContentCluster createCluster(int nodeCount, int maxNumberOfGroupsAllowedToBeDown) {
+ return createCluster(nodeCount, 1, maxNumberOfGroupsAllowedToBeDown);
}
- private ContentCluster createCluster(int nodeCount, int groupCount) {
- Collection<ConfiguredNode> nodes = createNodes(nodeCount);
+ private ContentCluster createCluster(int nodeCount, int groupCount, int maxNumberOfGroupsAllowedToBeDown) {
+ List<ConfiguredNode> nodes = createNodes(nodeCount);
Distribution distribution = new Distribution(createDistributionConfig(nodeCount, groupCount));
- return new ContentCluster("Clustername", nodes, distribution);
+ return new ContentCluster("Clustername", nodes, distribution, maxNumberOfGroupsAllowedToBeDown);
}
private String createDistributorHostInfo(int replicationfactor1, int replicationfactor2, int replicationfactor3) {
@@ -113,9 +108,10 @@ public class NodeStateChangeCheckerTest {
}
}
- @Test
- void testCanUpgradeForce() {
- var nodeStateChangeChecker = createChangeChecker(createCluster(1));
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testCanUpgradeWithForce(int maxNumberOfGroupsAllowedToBeDown) {
+ var nodeStateChangeChecker = createChangeChecker(createCluster(1, maxNumberOfGroupsAllowedToBeDown));
NodeState newState = new NodeState(STORAGE, INITIALIZING);
Result result = nodeStateChangeChecker.evaluateTransition(
nodeDistributor, defaultAllUpClusterState(), FORCE,
@@ -124,9 +120,10 @@ public class NodeStateChangeCheckerTest {
assertFalse(result.wantedStateAlreadySet());
}
- @Test
- void testDeniedInMoratorium() {
- ContentCluster cluster = createCluster(4);
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testDeniedInMoratorium(int maxNumberOfGroupsAllowedToBeDown) {
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
var nodeStateChangeChecker = new NodeStateChangeChecker(cluster, true);
Result result = nodeStateChangeChecker.evaluateTransition(
new Node(STORAGE, 10), defaultAllUpClusterState(), SAFE,
@@ -136,9 +133,10 @@ public class NodeStateChangeCheckerTest {
assertEquals("Master cluster controller is bootstrapping and in moratorium", result.getReason());
}
- @Test
- void testUnknownStorageNode() {
- ContentCluster cluster = createCluster(4);
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testUnknownStorageNode(int maxNumberOfGroupsAllowedToBeDown) {
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
var nodeStateChangeChecker = createChangeChecker(cluster);
Result result = nodeStateChangeChecker.evaluateTransition(
new Node(STORAGE, 10), defaultAllUpClusterState(), SAFE,
@@ -148,11 +146,12 @@ public class NodeStateChangeCheckerTest {
assertEquals("Unknown node storage.10", result.getReason());
}
- @Test
- void testSafeMaintenanceDisallowedWhenOtherStorageNodeInFlatClusterIsSuspended() {
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testSafeMaintenanceDisallowedWhenOtherStorageNodeInFlatClusterIsSuspended(int maxNumberOfGroupsAllowedToBeDown) {
// Nodes 0-3, storage node 0 being in maintenance with "Orchestrator" description.
- ContentCluster cluster = createCluster(4);
- cluster.clusterInfo().getStorageNodeInfo(0).setWantedState(new NodeState(STORAGE, MAINTENANCE).setDescription("Orchestrator"));
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
+ setStorageNodeWantedStateToMaintenance(cluster, 0);
var nodeStateChangeChecker = createChangeChecker(cluster);
ClusterState clusterStateWith0InMaintenance = clusterState(String.format(
"version:%d distributor:4 storage:4 .0.s:m",
@@ -168,9 +167,130 @@ public class NodeStateChangeCheckerTest {
}
@Test
- void testSafeMaintenanceDisallowedWhenOtherDistributorInFlatClusterIsSuspended() {
+ void testMaintenanceAllowedFor2Of4Groups() {
+ // 4 groups with 1 node in each group
+ Collection<ConfiguredNode> nodes = createNodes(4);
+ StorDistributionConfig config = createDistributionConfig(4, 4);
+
+ int maxNumberOfGroupsAllowedToBeDown = 2;
+ var cluster = new ContentCluster("Clustername", nodes, new Distribution(config), maxNumberOfGroupsAllowedToBeDown);
+ setAllNodesUp(cluster, HostInfo.createHostInfo(createDistributorHostInfo(4, 5, 6)));
+ var nodeStateChangeChecker = createChangeChecker(cluster);
+
+ // All nodes up, set a storage node in group 0 to maintenance
+ {
+ int nodeIndex = 0;
+ checkSettingToMaintenanceIsAllowed(nodeIndex, nodeStateChangeChecker, defaultAllUpClusterState());
+ setStorageNodeWantedStateToMaintenance(cluster, nodeIndex);
+ }
+
+ // Node in group 0 in maintenance, set storage node in group 1 to maintenance
+ {
+ ClusterState clusterState = clusterState(String.format("version:%d distributor:4 .0.s:d storage:4 .0.s:m", currentClusterStateVersion));
+ int nodeIndex = 1;
+ checkSettingToMaintenanceIsAllowed(nodeIndex, nodeStateChangeChecker, clusterState);
+ setStorageNodeWantedStateToMaintenance(cluster, nodeIndex);
+ }
+
+ // Nodes in group 0 and 1 in maintenance, try to set storage node in group 2 to maintenance while storage node 2 is down, should fail
+ {
+ ClusterState clusterState = clusterState(String.format("version:%d distributor:4 storage:4 .0.s:m .1.s:m .2.s:d", currentClusterStateVersion));
+ int nodeIndex = 2;
+ cluster.clusterInfo().getStorageNodeInfo(nodeIndex).setReportedState(new NodeState(STORAGE, DOWN), 0);
+ Node node = new Node(STORAGE, nodeIndex);
+ Result result = nodeStateChangeChecker.evaluateTransition(node, clusterState, SAFE, UP_NODE_STATE, MAINTENANCE_NODE_STATE);
+ assertFalse(result.settingWantedStateIsAllowed(), result.toString());
+ assertFalse(result.wantedStateAlreadySet());
+ assertEquals("At most 2 groups can have wanted state: [0, 1, 2]", result.getReason());
+ }
+
+ // Nodes in group 0 and 1 in maintenance, try to set storage node in group 2 to maintenance, should fail
+ {
+ ClusterState clusterState = clusterState(String.format("version:%d distributor:4 storage:4 .0.s:m .1.s:m", currentClusterStateVersion));
+ int nodeIndex = 2;
+ Node node = new Node(STORAGE, nodeIndex);
+ Result result = nodeStateChangeChecker.evaluateTransition(node, clusterState, SAFE, UP_NODE_STATE, MAINTENANCE_NODE_STATE);
+ assertFalse(result.settingWantedStateIsAllowed(), result.toString());
+ assertFalse(result.wantedStateAlreadySet());
+ assertEquals("At most 2 groups can have wanted state: [0, 1]", result.getReason());
+ }
+
+ }
+
+ @Test
+ void testMaintenanceAllowedFor2Of4Groups8Nodes() {
+ // 4 groups with 2 nodes in each group
+ Collection<ConfiguredNode> nodes = createNodes(8);
+ StorDistributionConfig config = createDistributionConfig(8, 4);
+
+ int maxNumberOfGroupsAllowedToBeDown = 2;
+ var cluster = new ContentCluster("Clustername", nodes, new Distribution(config), maxNumberOfGroupsAllowedToBeDown);
+ setAllNodesUp(cluster, HostInfo.createHostInfo(createDistributorHostInfo(4, 5, 6)));
+ var nodeStateChangeChecker = createChangeChecker(cluster);
+
+ // All nodes up, set a storage node in group 0 to maintenance
+ {
+ ClusterState clusterState = defaultAllUpClusterState(8);
+ int nodeIndex = 0;
+ checkSettingToMaintenanceIsAllowed(nodeIndex, nodeStateChangeChecker, clusterState);
+ setStorageNodeWantedStateToMaintenance(cluster, nodeIndex);
+ }
+
+ // 1 Node in group 0 in maintenance, try to set node 1 in group 0 to maintenance
+ {
+ ClusterState clusterState = clusterState(String.format("version:%d distributor:8 .0.s:d storage:8 .0.s:m", currentClusterStateVersion));
+ int nodeIndex = 1;
+ checkSettingToMaintenanceIsAllowed(nodeIndex, nodeStateChangeChecker, clusterState);
+ setStorageNodeWantedStateToMaintenance(cluster, nodeIndex);
+ }
+
+ // 2 nodes in group 0 in maintenance, try to set storage node 2 in group 1 to maintenance
+ {
+ ClusterState clusterState = clusterState(String.format("version:%d distributor:8 storage:8 .0.s:m .1.s:m", currentClusterStateVersion));
+ int nodeIndex = 2;
+ checkSettingToMaintenanceIsAllowed(nodeIndex, nodeStateChangeChecker, clusterState);
+ setStorageNodeWantedStateToMaintenance(cluster, nodeIndex);
+ }
+
+ // 2 nodes in group 0 and 1 in group 1 in maintenance, try to set storage node 4 in group 2 to maintenance, should fail (different group)
+ {
+ ClusterState clusterState = clusterState(String.format("version:%d distributor:8 storage:8 .0.s:m .1.s:m .2.s:m", currentClusterStateVersion));
+ int nodeIndex = 4;
+ Node node = new Node(STORAGE, nodeIndex);
+ Result result = nodeStateChangeChecker.evaluateTransition(node, clusterState, SAFE, UP_NODE_STATE, MAINTENANCE_NODE_STATE);
+ assertFalse(result.settingWantedStateIsAllowed(), result.toString());
+ assertFalse(result.wantedStateAlreadySet());
+ assertEquals("At most 2 groups can have wanted state: [0, 1]", result.getReason());
+ }
+
+ // 2 nodes in group 0 and 1 in group 1 in maintenance, try to set storage node 3 in group 1 to maintenance
+ {
+ ClusterState clusterState = clusterState(String.format("version:%d distributor:8 storage:8 .0.s:m .1.s:m .2.s:m", currentClusterStateVersion));
+ int nodeIndex = 3;
+ checkSettingToMaintenanceIsAllowed(nodeIndex, nodeStateChangeChecker, clusterState);
+ setStorageNodeWantedStateToMaintenance(cluster, nodeIndex);
+ }
+
+ // 2 nodes in group 0 in maintenance, storage node 3 in group 1 is in maintenance with another description
+ // (set in maintenance by operator), try to set storage node 3 in group 1 to maintenance, should bew allowed
+ {
+ ClusterState clusterState = clusterState(String.format("version:%d distributor:8 storage:8 .0.s:m .1.s:m .3.s:m", currentClusterStateVersion));
+ setStorageNodeWantedState(cluster, 3, MAINTENANCE, "Maintenance, set by operator"); // Set to another description
+ setStorageNodeWantedState(cluster, 2, UP, ""); // Set back to UP, want to set this to maintenance again
+ int nodeIndex = 2;
+ Node node = new Node(STORAGE, nodeIndex);
+ Result result = nodeStateChangeChecker.evaluateTransition(node, clusterState, SAFE, UP_NODE_STATE, MAINTENANCE_NODE_STATE);
+ assertTrue(result.settingWantedStateIsAllowed(), result.toString());
+ assertFalse(result.wantedStateAlreadySet());
+ }
+
+ }
+
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testSafeMaintenanceDisallowedWhenOtherDistributorInFlatClusterIsSuspended(int maxNumberOfGroupsAllowedToBeDown) {
// Nodes 0-3, distributor 0 being down with "Orchestrator" description.
- ContentCluster cluster = createCluster(4);
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
setDistributorNodeWantedState(cluster, 0, DOWN, "Orchestrator");
var nodeStateChangeChecker = createChangeChecker(cluster);
ClusterState clusterStateWith0InMaintenance = clusterState(String.format(
@@ -186,11 +306,12 @@ public class NodeStateChangeCheckerTest {
result.getReason());
}
- @Test
- void testSafeMaintenanceDisallowedWhenDistributorInGroupIsDown() {
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testSafeMaintenanceDisallowedWhenDistributorInGroupIsDown(int maxNumberOfGroupsAllowedToBeDown) {
// Nodes 0-3, distributor 0 being in maintenance with "Orchestrator" description.
// 2 groups: nodes 0-1 is group 0, 2-3 is group 1.
- ContentCluster cluster = createCluster(4, 2);
+ ContentCluster cluster = createCluster(4, 2, maxNumberOfGroupsAllowedToBeDown);
setDistributorNodeWantedState(cluster, 0, DOWN, "Orchestrator");
var nodeStateChangeChecker = new NodeStateChangeChecker(cluster, false);
ClusterState clusterStateWith0InMaintenance = clusterState(String.format(
@@ -204,7 +325,10 @@ public class NodeStateChangeCheckerTest {
SAFE, UP_NODE_STATE, MAINTENANCE_NODE_STATE);
assertFalse(result.settingWantedStateIsAllowed());
assertFalse(result.wantedStateAlreadySet());
- assertEquals("At most one group can have wanted state: Other distributor 0 in group 0 has wanted state Down", result.getReason());
+ if (maxNumberOfGroupsAllowedToBeDown >= 1)
+ assertEquals("Wanted state already set for another node in groups: [0]", result.getReason());
+ else
+ assertEquals("At most one group can have wanted state: Other distributor 0 in group 0 has wanted state Down", result.getReason());
}
{
@@ -213,16 +337,22 @@ public class NodeStateChangeCheckerTest {
Result result = nodeStateChangeChecker.evaluateTransition(
new Node(STORAGE, 1), clusterStateWith0InMaintenance,
SAFE, UP_NODE_STATE, MAINTENANCE_NODE_STATE);
- assertFalse(result.settingWantedStateIsAllowed(), result.getReason());
- assertEquals("Another distributor wants state DOWN: 0", result.getReason());
+ if (maxNumberOfGroupsAllowedToBeDown >= 1) {
+ assertFalse(result.settingWantedStateIsAllowed(), result.getReason());
+ assertEquals("Wanted state already set for another node in groups: [0]", result.getReason());
+ } else {
+ assertFalse(result.settingWantedStateIsAllowed(), result.getReason());
+ assertEquals("Another distributor wants state DOWN: 0", result.getReason());
+ }
}
}
- @Test
- void testSafeMaintenanceWhenOtherStorageNodeInGroupIsSuspended() {
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testSafeMaintenanceWhenOtherStorageNodeInGroupIsSuspended(int maxNumberOfGroupsAllowedToBeDown) {
// Nodes 0-3, storage node 0 being in maintenance with "Orchestrator" description.
// 2 groups: nodes 0-1 is group 0, 2-3 is group 1.
- ContentCluster cluster = createCluster(4, 2);
+ ContentCluster cluster = createCluster(4, 2, maxNumberOfGroupsAllowedToBeDown);
setStorageNodeWantedState(cluster, 0, MAINTENANCE, "Orchestrator");
var nodeStateChangeChecker = new NodeStateChangeChecker(cluster, false);
ClusterState clusterStateWith0InMaintenance = clusterState(String.format(
@@ -236,8 +366,11 @@ public class NodeStateChangeCheckerTest {
SAFE, UP_NODE_STATE, MAINTENANCE_NODE_STATE);
assertFalse(result.settingWantedStateIsAllowed());
assertFalse(result.wantedStateAlreadySet());
- assertEquals("At most one group can have wanted state: Other storage node 0 in group 0 has wanted state Maintenance",
- result.getReason());
+ if (maxNumberOfGroupsAllowedToBeDown >= 1)
+ assertEquals("At most 1 groups can have wanted state: [0]", result.getReason());
+ else
+ assertEquals("At most one group can have wanted state: Other storage node 0 in group 0 has wanted state Maintenance",
+ result.getReason());
}
{
@@ -251,9 +384,10 @@ public class NodeStateChangeCheckerTest {
}
}
- @Test
- void testSafeSetStateDistributors() {
- NodeStateChangeChecker nodeStateChangeChecker = createChangeChecker(createCluster(1));
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testSafeSetStateDistributors(int maxNumberOfGroupsAllowedToBeDown) {
+ NodeStateChangeChecker nodeStateChangeChecker = createChangeChecker(createCluster(1, 1, maxNumberOfGroupsAllowedToBeDown));
Result result = nodeStateChangeChecker.evaluateTransition(
nodeDistributor, defaultAllUpClusterState(), SAFE,
UP_NODE_STATE, MAINTENANCE_NODE_STATE);
@@ -262,10 +396,11 @@ public class NodeStateChangeCheckerTest {
assertTrue(result.getReason().contains("Safe-set of node state is only supported for storage nodes"));
}
- @Test
- void testCanUpgradeSafeMissingStorage() {
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testCanUpgradeSafeMissingStorage(int maxNumberOfGroupsAllowedToBeDown) {
// Create a content cluster with 4 nodes, and storage node with index 3 down.
- ContentCluster cluster = createCluster(4);
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
setAllNodesUp(cluster, HostInfo.createHostInfo(createDistributorHostInfo(4, 5, 6)));
cluster.clusterInfo().getStorageNodeInfo(3).setReportedState(new NodeState(STORAGE, DOWN), 0);
ClusterState clusterStateWith3Down = clusterState(String.format(
@@ -282,16 +417,18 @@ public class NodeStateChangeCheckerTest {
assertEquals("Another storage node has state DOWN: 3", result.getReason());
}
- @Test
- void testCanUpgradeStorageSafeYes() {
- Result result = transitionToMaintenanceWithNoStorageNodesDown(createCluster(4), defaultAllUpClusterState());
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testCanUpgradeStorageSafeYes(int maxNumberOfGroupsAllowedToBeDown) {
+ Result result = transitionToMaintenanceWithNoStorageNodesDown(createCluster(4, 1, maxNumberOfGroupsAllowedToBeDown), defaultAllUpClusterState());
assertTrue(result.settingWantedStateIsAllowed());
assertFalse(result.wantedStateAlreadySet());
}
- @Test
- void testSetUpFailsIfReportedIsDown() {
- ContentCluster cluster = createCluster(4);
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testSetUpFailsIfReportedIsDown(int maxNumberOfGroupsAllowedToBeDown) {
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
NodeStateChangeChecker nodeStateChangeChecker = createChangeChecker(cluster);
// Not setting nodes up -> all are down
@@ -304,9 +441,10 @@ public class NodeStateChangeCheckerTest {
// A node may be reported as Up but have a generated state of Down if it's part of
// nodes taken down implicitly due to a group having too low node availability.
- @Test
- void testSetUpSucceedsIfReportedIsUpButGeneratedIsDown() {
- ContentCluster cluster = createCluster(4);
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testSetUpSucceedsIfReportedIsUpButGeneratedIsDown(int maxNumberOfGroupsAllowedToBeDown) {
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
NodeStateChangeChecker nodeStateChangeChecker = createChangeChecker(cluster);
markAllNodesAsReportingStateUp(cluster);
@@ -322,9 +460,10 @@ public class NodeStateChangeCheckerTest {
assertFalse(result.wantedStateAlreadySet());
}
- @Test
- void testCanSetUpEvenIfOldWantedStateIsDown() {
- ContentCluster cluster = createCluster(4);
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testCanSetUpEvenIfOldWantedStateIsDown(int maxNumberOfGroupsAllowedToBeDown) {
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
NodeStateChangeChecker nodeStateChangeChecker = createChangeChecker(cluster);
setAllNodesUp(cluster, HostInfo.createHostInfo(createDistributorHostInfo(4, 3, 6)));
@@ -335,9 +474,10 @@ public class NodeStateChangeCheckerTest {
assertFalse(result.wantedStateAlreadySet());
}
- @Test
- void testCanUpgradeStorageSafeNo() {
- ContentCluster cluster = createCluster(4);
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testCanUpgradeStorageSafeNo(int maxNumberOfGroupsAllowedToBeDown) {
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
NodeStateChangeChecker nodeStateChangeChecker = createChangeChecker(cluster);
setAllNodesUp(cluster, HostInfo.createHostInfo(createDistributorHostInfo(4, 3, 6)));
@@ -350,9 +490,10 @@ public class NodeStateChangeCheckerTest {
result.getReason());
}
- @Test
- void testCanUpgradeIfMissingMinReplicationFactor() {
- ContentCluster cluster = createCluster(4);
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testCanUpgradeIfMissingMinReplicationFactor(int maxNumberOfGroupsAllowedToBeDown) {
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
NodeStateChangeChecker nodeStateChangeChecker = createChangeChecker(cluster);
setAllNodesUp(cluster, HostInfo.createHostInfo(createDistributorHostInfo(4, 3, 6)));
@@ -363,9 +504,10 @@ public class NodeStateChangeCheckerTest {
assertFalse(result.wantedStateAlreadySet());
}
- @Test
- void testCanUpgradeIfStorageNodeMissingFromNodeInfo() {
- ContentCluster cluster = createCluster(4);
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testCanUpgradeIfStorageNodeMissingFromNodeInfo(int maxNumberOfGroupsAllowedToBeDown) {
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
NodeStateChangeChecker nodeStateChangeChecker = createChangeChecker(cluster);
String hostInfo = "{\n" +
" \"cluster-state-version\": 2,\n" +
@@ -387,9 +529,10 @@ public class NodeStateChangeCheckerTest {
assertFalse(result.wantedStateAlreadySet());
}
- @Test
- void testMissingDistributorState() {
- ContentCluster cluster = createCluster(4);
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testMissingDistributorState(int maxNumberOfGroupsAllowedToBeDown) {
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
NodeStateChangeChecker nodeStateChangeChecker = createChangeChecker(cluster);
cluster.clusterInfo().getStorageNodeInfo(1).setReportedState(new NodeState(STORAGE, UP), 0);
@@ -400,8 +543,8 @@ public class NodeStateChangeCheckerTest {
assertEquals("Distributor node 0 has not reported any cluster state version yet.", result.getReason());
}
- private Result transitionToSameState(State state, String oldDescription, String newDescription) {
- ContentCluster cluster = createCluster(4);
+ private Result transitionToSameState(State state, String oldDescription, String newDescription, int maxNumberOfGroupsAllowedToBeDown) {
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
NodeStateChangeChecker nodeStateChangeChecker = createChangeChecker(cluster);
NodeState currentNodeState = createNodeState(state, oldDescription);
@@ -411,26 +554,29 @@ public class NodeStateChangeCheckerTest {
currentNodeState, newNodeState);
}
- private Result transitionToSameState(String oldDescription, String newDescription) {
- return transitionToSameState(MAINTENANCE, oldDescription, newDescription);
+ private Result transitionToSameState(String oldDescription, String newDescription, int maxNumberOfGroupsAllowedToBeDown) {
+ return transitionToSameState(MAINTENANCE, oldDescription, newDescription, maxNumberOfGroupsAllowedToBeDown);
}
- @Test
- void testSettingUpWhenUpCausesAlreadySet() {
- Result result = transitionToSameState(UP, "foo", "bar");
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testSettingUpWhenUpCausesAlreadySet(int maxNumberOfGroupsAllowedToBeDown) {
+ Result result = transitionToSameState(UP, "foo", "bar", maxNumberOfGroupsAllowedToBeDown);
assertTrue(result.wantedStateAlreadySet());
}
- @Test
- void testSettingAlreadySetState() {
- Result result = transitionToSameState("foo", "foo");
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testSettingAlreadySetState(int maxNumberOfGroupsAllowedToBeDown) {
+ Result result = transitionToSameState("foo", "foo", maxNumberOfGroupsAllowedToBeDown);
assertFalse(result.settingWantedStateIsAllowed());
assertTrue(result.wantedStateAlreadySet());
}
- @Test
- void testDifferentDescriptionImpliesDenied() {
- Result result = transitionToSameState("foo", "bar");
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testDifferentDescriptionImpliesDenied(int maxNumberOfGroupsAllowedToBeDown) {
+ Result result = transitionToSameState("foo", "bar", maxNumberOfGroupsAllowedToBeDown);
assertFalse(result.settingWantedStateIsAllowed());
assertFalse(result.wantedStateAlreadySet());
}
@@ -439,10 +585,9 @@ public class NodeStateChangeCheckerTest {
NodeStateChangeChecker nodeStateChangeChecker = createChangeChecker(cluster);
for (int x = 0; x < cluster.clusterInfo().getConfiguredNodes().size(); x++) {
- State state = UP;
- cluster.clusterInfo().getDistributorNodeInfo(x).setReportedState(new NodeState(DISTRIBUTOR, state), 0);
+ cluster.clusterInfo().getDistributorNodeInfo(x).setReportedState(new NodeState(DISTRIBUTOR, UP), 0);
cluster.clusterInfo().getDistributorNodeInfo(x).setHostInfo(HostInfo.createHostInfo(createDistributorHostInfo(4, 5, 6)));
- cluster.clusterInfo().getStorageNodeInfo(x).setReportedState(new NodeState(STORAGE, state), 0);
+ cluster.clusterInfo().getStorageNodeInfo(x).setReportedState(new NodeState(STORAGE, UP), 0);
}
return nodeStateChangeChecker.evaluateTransition(
@@ -462,26 +607,29 @@ public class NodeStateChangeCheckerTest {
return transitionToMaintenanceWithOneStorageNodeDown(cluster, clusterState);
}
- @Test
- void testCanUpgradeWhenAllUp() {
- Result result = transitionToMaintenanceWithNoStorageNodesDown(createCluster(4), defaultAllUpClusterState());
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testCanUpgradeWhenAllUp(int maxNumberOfGroupsAllowedToBeDown) {
+ Result result = transitionToMaintenanceWithNoStorageNodesDown(createCluster(4, maxNumberOfGroupsAllowedToBeDown), defaultAllUpClusterState());
assertTrue(result.settingWantedStateIsAllowed());
assertFalse(result.wantedStateAlreadySet());
}
- @Test
- void testCanUpgradeWhenAllUpOrRetired() {
- Result result = transitionToMaintenanceWithNoStorageNodesDown(createCluster(4), defaultAllUpClusterState());
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testCanUpgradeWhenAllUpOrRetired(int maxNumberOfGroupsAllowedToBeDown) {
+ Result result = transitionToMaintenanceWithNoStorageNodesDown(createCluster(4, maxNumberOfGroupsAllowedToBeDown), defaultAllUpClusterState());
assertTrue(result.settingWantedStateIsAllowed());
assertFalse(result.wantedStateAlreadySet());
}
- @Test
- void testCanUpgradeWhenStorageIsDown() {
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testCanUpgradeWhenStorageIsDown(int maxNumberOfGroupsAllowedToBeDown) {
ClusterState clusterState = defaultAllUpClusterState();
var storageNodeIndex = nodeStorage.getIndex();
- ContentCluster cluster = createCluster(4);
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
NodeState downNodeState = new NodeState(STORAGE, DOWN);
cluster.clusterInfo().getStorageNodeInfo(storageNodeIndex).setReportedState(downNodeState, 4 /* time */);
clusterState.setNodeState(new Node(STORAGE, storageNodeIndex), downNodeState);
@@ -491,13 +639,14 @@ public class NodeStateChangeCheckerTest {
assertFalse(result.wantedStateAlreadySet());
}
- @Test
- void testCannotUpgradeWhenOtherStorageIsDown() {
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testCannotUpgradeWhenOtherStorageIsDown(int maxNumberOfGroupsAllowedToBeDown) {
int otherIndex = 2;
// If this fails, just set otherIndex to some other valid index.
assertNotEquals(nodeStorage.getIndex(), otherIndex);
- ContentCluster cluster = createCluster(4);
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
ClusterState clusterState = defaultAllUpClusterState();
NodeState downNodeState = new NodeState(STORAGE, DOWN);
cluster.clusterInfo().getStorageNodeInfo(otherIndex).setReportedState(downNodeState, 4 /* time */);
@@ -509,9 +658,10 @@ public class NodeStateChangeCheckerTest {
assertTrue(result.getReason().contains("Another storage node has state DOWN: 2"));
}
- @Test
- void testNodeRatioRequirementConsidersGeneratedNodeStates() {
- ContentCluster cluster = createCluster(4);
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testNodeRatioRequirementConsidersGeneratedNodeStates(int maxNumberOfGroupsAllowedToBeDown) {
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
NodeStateChangeChecker nodeStateChangeChecker = createChangeChecker(cluster);
markAllNodesAsReportingStateUp(cluster);
@@ -531,62 +681,72 @@ public class NodeStateChangeCheckerTest {
assertFalse(result.wantedStateAlreadySet());
}
- @Test
- void testDownDisallowedByNonRetiredState() {
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testDownDisallowedByNonRetiredState(int maxNumberOfGroupsAllowedToBeDown) {
Result result = evaluateDownTransition(
defaultAllUpClusterState(),
UP,
currentClusterStateVersion,
- 0);
+ 0,
+ maxNumberOfGroupsAllowedToBeDown);
assertFalse(result.settingWantedStateIsAllowed());
assertFalse(result.wantedStateAlreadySet());
assertEquals("Only retired nodes are allowed to be set to DOWN in safe mode - is Up", result.getReason());
}
- @Test
- void testDownDisallowedByBuckets() {
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testDownDisallowedByBuckets(int maxNumberOfGroupsAllowedToBeDown) {
Result result = evaluateDownTransition(
retiredClusterStateSuffix(),
UP,
currentClusterStateVersion,
- 1);
+ 1,
+ maxNumberOfGroupsAllowedToBeDown);
assertFalse(result.settingWantedStateIsAllowed());
assertFalse(result.wantedStateAlreadySet());
assertEquals("The storage node manages 1 buckets", result.getReason());
}
- @Test
- void testDownDisallowedByReportedState() {
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testDownDisallowedByReportedState(int maxNumberOfGroupsAllowedToBeDown) {
Result result = evaluateDownTransition(
retiredClusterStateSuffix(),
INITIALIZING,
currentClusterStateVersion,
- 0);
+ 0,
+ maxNumberOfGroupsAllowedToBeDown);
assertFalse(result.settingWantedStateIsAllowed());
assertFalse(result.wantedStateAlreadySet());
assertEquals("Reported state (Initializing) is not UP, so no bucket data is available", result.getReason());
}
- @Test
- void testDownDisallowedByVersionMismatch() {
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testDownDisallowedByVersionMismatch(int maxNumberOfGroupsAllowedToBeDown) {
Result result = evaluateDownTransition(
retiredClusterStateSuffix(),
UP,
currentClusterStateVersion - 1,
- 0);
+ 0,
+ maxNumberOfGroupsAllowedToBeDown);
assertFalse(result.settingWantedStateIsAllowed());
assertFalse(result.wantedStateAlreadySet());
assertEquals("Cluster controller at version 2 got info for storage node 1 at a different version 1",
result.getReason());
}
- @Test
- void testAllowedToSetDown() {
+ @ParameterizedTest
+ @ValueSource(ints = {-1, 1})
+ void testAllowedToSetDown(int maxNumberOfGroupsAllowedToBeDown) {
Result result = evaluateDownTransition(
retiredClusterStateSuffix(),
UP,
currentClusterStateVersion,
- 0);
+ 0,
+ maxNumberOfGroupsAllowedToBeDown);
assertTrue(result.settingWantedStateIsAllowed());
assertFalse(result.wantedStateAlreadySet());
}
@@ -594,8 +754,9 @@ public class NodeStateChangeCheckerTest {
private Result evaluateDownTransition(ClusterState clusterState,
State reportedState,
int hostInfoClusterStateVersion,
- int lastAlldisksBuckets) {
- ContentCluster cluster = createCluster(4);
+ int lastAlldisksBuckets,
+ int maxNumberOfGroupsAllowedToBeDown) {
+ ContentCluster cluster = createCluster(4, maxNumberOfGroupsAllowedToBeDown);
NodeStateChangeChecker nodeStateChangeChecker = createChangeChecker(cluster);
StorageNodeInfo nodeInfo = cluster.clusterInfo().getStorageNodeInfo(nodeStorage.getIndex());
@@ -763,6 +924,18 @@ public class NodeStateChangeCheckerTest {
return configBuilder.build();
}
+ private void checkSettingToMaintenanceIsAllowed(int nodeIndex, NodeStateChangeChecker nodeStateChangeChecker, ClusterState clusterState) {
+ Node node = new Node(STORAGE, nodeIndex);
+ Result result = nodeStateChangeChecker.evaluateTransition(node, clusterState, SAFE, UP_NODE_STATE, MAINTENANCE_NODE_STATE);
+ assertTrue(result.settingWantedStateIsAllowed(), result.toString());
+ assertFalse(result.wantedStateAlreadySet());
+ assertEquals("Preconditions fulfilled and new state different", result.getReason());
+ }
+
+ private void setStorageNodeWantedStateToMaintenance(ContentCluster cluster, int nodeIndex) {
+ setStorageNodeWantedState(cluster, nodeIndex, MAINTENANCE, "Orchestrator");
+ }
+
private void setStorageNodeWantedState(ContentCluster cluster, int nodeIndex, State state, String description) {
NodeState nodeState = new NodeState(STORAGE, state);
cluster.clusterInfo().getStorageNodeInfo(nodeIndex).setWantedState(nodeState.setDescription(description));
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java
index fef2354c452..7f2dd4b6acd 100644
--- a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java
+++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java
@@ -113,6 +113,7 @@ public interface ModelContext {
@ModelFeatureFlag(owners = {"tokle, bjorncs"}, removeAfter = "8.108") default boolean enableDataPlaneFilter() { return true; }
@ModelFeatureFlag(owners = {"arnej, bjorncs"}) default boolean enableGlobalPhase() { return true; }
@ModelFeatureFlag(owners = {"baldersheim"}, comment = "Select summary decode type") default String summaryDecodePolicy() { return "eager"; }
+ @ModelFeatureFlag(owners = {"hmusum"}) default boolean allowMoreThanOneContentGroupDown(ClusterSpec.Id id) { return false; }
//Below are all flags that must be kept until 7 is out of the door
@ModelFeatureFlag(owners = {"arnej"}, removeAfter="7.last") default boolean ignoreThreadStackSizes() { return false; }
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/ClusterControllerConfig.java b/config-model/src/main/java/com/yahoo/vespa/model/content/ClusterControllerConfig.java
index 8ec4ae35658..201e0b5693a 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/content/ClusterControllerConfig.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/content/ClusterControllerConfig.java
@@ -14,8 +14,6 @@ import org.w3c.dom.Element;
/**
* Config generation for common parameters for all fleet controllers.
- *
- * TODO: Author
*/
public class ClusterControllerConfig extends AnyConfigProducer implements FleetcontrollerConfig.Producer {
@@ -23,11 +21,16 @@ public class ClusterControllerConfig extends AnyConfigProducer implements Fleetc
private final String clusterName;
private final ModelElement clusterElement;
private final ResourceLimits resourceLimits;
+ private final boolean allowMoreThanOneContentGroupDown;
- public Builder(String clusterName, ModelElement clusterElement, ResourceLimits resourceLimits) {
+ public Builder(String clusterName,
+ ModelElement clusterElement,
+ ResourceLimits resourceLimits,
+ boolean allowMoreThanOneContentGroupDown) {
this.clusterName = clusterName;
this.clusterElement = clusterElement;
this.resourceLimits = resourceLimits;
+ this.allowMoreThanOneContentGroupDown = allowMoreThanOneContentGroupDown;
}
@Override
@@ -53,13 +56,15 @@ public class ClusterControllerConfig extends AnyConfigProducer implements Fleetc
tuning.childAsDouble("min-storage-up-ratio"),
bucketSplittingMinimumBits,
minNodeRatioPerGroup,
- resourceLimits);
+ resourceLimits,
+ allowMoreThanOneContentGroupDown);
} else {
return new ClusterControllerConfig(ancestor, clusterName,
null, null, null, null, null, null,
bucketSplittingMinimumBits,
minNodeRatioPerGroup,
- resourceLimits);
+ resourceLimits,
+ allowMoreThanOneContentGroupDown);
}
}
}
@@ -74,6 +79,7 @@ public class ClusterControllerConfig extends AnyConfigProducer implements Fleetc
private final Integer minSplitBits;
private final Double minNodeRatioPerGroup;
private final ResourceLimits resourceLimits;
+ private final boolean allowMoreThanOneContentGroupDown;
// TODO refactor; too many args
private ClusterControllerConfig(TreeConfigProducer<?> parent,
@@ -86,7 +92,8 @@ public class ClusterControllerConfig extends AnyConfigProducer implements Fleetc
Double minStorageUpRatio,
Integer minSplitBits,
Double minNodeRatioPerGroup,
- ResourceLimits resourceLimits) {
+ ResourceLimits resourceLimits,
+ boolean allowMoreThanOneContentGroupDown) {
super(parent, "fleetcontroller");
this.clusterName = clusterName;
@@ -99,6 +106,7 @@ public class ClusterControllerConfig extends AnyConfigProducer implements Fleetc
this.minSplitBits = minSplitBits;
this.minNodeRatioPerGroup = minNodeRatioPerGroup;
this.resourceLimits = resourceLimits;
+ this.allowMoreThanOneContentGroupDown = allowMoreThanOneContentGroupDown;
}
@Override
@@ -141,6 +149,7 @@ public class ClusterControllerConfig extends AnyConfigProducer implements Fleetc
builder.min_node_ratio_per_group(minNodeRatioPerGroup);
}
resourceLimits.getConfig(builder);
+ builder.max_number_of_groups_allowed_to_be_down(allowMoreThanOneContentGroupDown ? 1 : -1);
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java
index 7f4fc4cd89d..217c26516a9 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java
@@ -127,7 +127,8 @@ public class ContentCluster extends TreeConfigProducer<AnyConfigProducer> implem
.build(contentElement);
c.clusterControllerConfig = new ClusterControllerConfig.Builder(clusterId,
contentElement,
- resourceLimits.getClusterControllerLimits())
+ resourceLimits.getClusterControllerLimits(),
+ deployState.featureFlags().allowMoreThanOneContentGroupDown(new ClusterSpec.Id(clusterId)))
.build(deployState, c, contentElement.getXml());
c.search = new ContentSearchCluster.Builder(documentDefinitions,
globallyDistributedDocuments,
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/NodeResourcesTuning.java b/config-model/src/main/java/com/yahoo/vespa/model/search/NodeResourcesTuning.java
index ee18eceb719..2de06e2053a 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/search/NodeResourcesTuning.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/search/NodeResourcesTuning.java
@@ -18,7 +18,6 @@ public class NodeResourcesTuning implements ProtonConfig.Producer {
private final static double SUMMARY_CACHE_SIZE_AS_FRACTION_OF_MEMORY = 0.04;
private final static double MEMORY_GAIN_AS_FRACTION_OF_MEMORY = 0.08;
private final static double MIN_MEMORY_PER_FLUSH_THREAD_GB = 16.0;
- private final static double MAX_FLUSH_THREAD_RATIO = 1.0/8;
private final static double TLS_SIZE_FRACTION = 0.02;
final static long MB = 1024 * 1024;
public final static long GB = MB * 1024;
@@ -94,13 +93,12 @@ public class NodeResourcesTuning implements ProtonConfig.Producer {
}
private void tuneFlushConcurrentThreads(ProtonConfig.Flush.Builder builder) {
+ int max_concurrent = 2; // TODO bring slowly up towards 4
if (usableMemoryGb() < MIN_MEMORY_PER_FLUSH_THREAD_GB) {
- builder.maxconcurrent(1);
+ max_concurrent = 1;
}
- double min_concurrent_mem = usableMemoryGb() / (2*MIN_MEMORY_PER_FLUSH_THREAD_GB);
- double min_concurrent_cpu = resources.vcpu() * MAX_FLUSH_THREAD_RATIO;
- builder.maxconcurrent(Math.min(builder.build().maxconcurrent(),
- (int)Math.ceil(Math.max(min_concurrent_mem, min_concurrent_cpu))));
+ double min_concurrent_mem = usableMemoryGb() / MIN_MEMORY_PER_FLUSH_THREAD_GB;
+ builder.maxconcurrent(Math.min(max_concurrent, (int)Math.ceil(min_concurrent_mem)));
}
private void tuneFlushStrategyTlsSize(ProtonConfig.Flush.Memory.Builder builder) {
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/FleetControllerClusterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/content/FleetControllerClusterTest.java
index 1e6847a47be..1f8dea41a3e 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/content/FleetControllerClusterTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/content/FleetControllerClusterTest.java
@@ -27,7 +27,8 @@ public class FleetControllerClusterTest {
new ClusterResourceLimits.Builder(false,
featureFlags.resourceLimitDisk(),
featureFlags.resourceLimitMemory())
- .build(clusterElement).getClusterControllerLimits())
+ .build(clusterElement).getClusterControllerLimits(),
+ false)
.build(root.getDeployState(), root, clusterElement.getXml());
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/search/NodeResourcesTuningTest.java b/config-model/src/test/java/com/yahoo/vespa/model/search/NodeResourcesTuningTest.java
index 9fe38512fc0..d344be3da9a 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/search/NodeResourcesTuningTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/search/NodeResourcesTuningTest.java
@@ -182,14 +182,12 @@ public class NodeResourcesTuningTest {
}
@Test
public void require_that_concurrent_flush_threads_is_1_with_low_memory() {
- assertEquals(2, fromMemAndCpu(17, 9).flush().maxconcurrent());
- assertEquals(2, fromMemAndCpu(17, 64).flush().maxconcurrent()); // still capped by max
- assertEquals(2, fromMemAndCpu(65, 8).flush().maxconcurrent()); // still capped by max
- assertEquals(2, fromMemAndCpu(33, 8).flush().maxconcurrent());
- assertEquals(1, fromMemAndCpu(31, 8).flush().maxconcurrent());
- assertEquals(1, fromMemAndCpu(15, 8).flush().maxconcurrent());
- assertEquals(1, fromMemAndCpu(17, 8).flush().maxconcurrent());
+ assertEquals(1, fromMemAndCpu(1, 8).flush().maxconcurrent());
assertEquals(1, fromMemAndCpu(15, 8).flush().maxconcurrent());
+ assertEquals(1, fromMemAndCpu(16, 8).flush().maxconcurrent());
+ assertEquals(2, fromMemAndCpu(17, 8).flush().maxconcurrent());
+ assertEquals(2, fromMemAndCpu(65, 8).flush().maxconcurrent()); // still capped by max
+ assertEquals(2, fromMemAndCpu(65, 65).flush().maxconcurrent()); // still capped by max
}
private static void assertDocumentStoreMaxFileSize(long expFileSizeBytes, int wantedMemoryGb) {
diff --git a/configdefinitions/src/vespa/fleetcontroller.def b/configdefinitions/src/vespa/fleetcontroller.def
index f95dada28d9..93a20e4ee0d 100644
--- a/configdefinitions/src/vespa/fleetcontroller.def
+++ b/configdefinitions/src/vespa/fleetcontroller.def
@@ -64,11 +64,6 @@ init_progress_time int default=0
## we dont change the state too often.
min_time_between_new_systemstates int default=10000
-## Sets how many milliseconds to wait between each state poll for old nodes
-## requiring state polling. (4.1 or older)
-## TODO: Not used, remove in Vespa 9
-state_polling_frequency int default=5000
-
## The maximum amount of premature crashes a node is allowed to have in a row
## before the fleetcontroller disables that node.
max_premature_crashes int default=100000
@@ -181,9 +176,6 @@ min_merge_completion_ratio double default=1.0
## transition logic aims to minimize the window of time where active states diverge.
enable_two_phase_cluster_state_transitions bool default=false
-## TODO: Deprecated - not used, remove in Vespa 9
-determine_buckets_from_bucket_space_metric bool default=true
-
# If enabled, the cluster controller observes reported (categorized) resource usage from content nodes (via host info),
# and decides whether external feed should be blocked (or unblocked) in the entire cluster.
#
@@ -207,6 +199,7 @@ cluster_feed_block_limit{} double
# This is in absolute numbers, so 0.01 implies that a block limit of 0.8 effectively
# becomes 0.79 for an already blocked node.
cluster_feed_block_noise_level double default=0.0
-# For apps that have several groups this controls how many are allowed to be down
-# simultaneously.
-max_number_of_groups_allowed_to_be_down int default=1
+
+# For apps that have several groups this controls how many groups are allowed to
+# be down simultaneously in this cluster.
+max_number_of_groups_allowed_to_be_down int default=-1
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
index 7a2377594a1..62431ce4c06 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
@@ -22,10 +22,10 @@ import com.yahoo.config.provision.AthenzDomain;
import com.yahoo.config.provision.CloudAccount;
import com.yahoo.config.provision.ClusterSpec;
import com.yahoo.config.provision.DockerImage;
+import com.yahoo.config.provision.HostName;
import com.yahoo.config.provision.TenantName;
import com.yahoo.config.provision.Zone;
import com.yahoo.container.jdisc.secretstore.SecretStore;
-import com.yahoo.config.provision.HostName;
import com.yahoo.vespa.config.server.tenant.SecretStoreExternalIdRetriever;
import com.yahoo.vespa.flags.FetchVector;
import com.yahoo.vespa.flags.FlagSource;
@@ -33,7 +33,6 @@ import com.yahoo.vespa.flags.Flags;
import com.yahoo.vespa.flags.PermanentFlags;
import com.yahoo.vespa.flags.StringFlag;
import com.yahoo.vespa.flags.UnboundFlag;
-
import java.io.File;
import java.net.URI;
import java.security.cert.X509Certificate;
@@ -41,7 +40,7 @@ import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
-import java.util.function.ToIntFunction;
+import java.util.function.Predicate;
import static com.yahoo.config.provision.NodeResources.Architecture;
import static com.yahoo.vespa.config.server.ConfigServerSpec.fromConfig;
@@ -174,7 +173,7 @@ public class ModelContextImpl implements ModelContext {
private final double feedNiceness;
private final List<String> allowedAthenzProxyIdentities;
private final int maxActivationInhibitedOutOfSyncGroups;
- private final ToIntFunction<ClusterSpec.Type> jvmOmitStackTraceInFastThrow;
+ private final Predicate<ClusterSpec.Type> jvmOmitStackTraceInFastThrow;
private final double resourceLimitDisk;
private final double resourceLimitMemory;
private final double minNodeRatioPerGroup;
@@ -204,6 +203,7 @@ public class ModelContextImpl implements ModelContext {
private final int heapPercentage;
private final boolean enableGlobalPhase;
private final String summaryDecodePolicy;
+ private final Predicate<ClusterSpec.Id> allowMoreThanOneContentGroupDown;
public FeatureFlags(FlagSource source, ApplicationId appId, Version version) {
this.defaultTermwiseLimit = flagValue(source, appId, version, Flags.DEFAULT_TERM_WISE_LIMIT);
@@ -219,7 +219,7 @@ public class ModelContextImpl implements ModelContext {
this.mbus_network_threads = flagValue(source, appId, version, Flags.MBUS_NUM_NETWORK_THREADS);
this.allowedAthenzProxyIdentities = flagValue(source, appId, version, Flags.ALLOWED_ATHENZ_PROXY_IDENTITIES);
this.maxActivationInhibitedOutOfSyncGroups = flagValue(source, appId, version, Flags.MAX_ACTIVATION_INHIBITED_OUT_OF_SYNC_GROUPS);
- this.jvmOmitStackTraceInFastThrow = type -> flagValueAsInt(source, appId, version, type, PermanentFlags.JVM_OMIT_STACK_TRACE_IN_FAST_THROW);
+ this.jvmOmitStackTraceInFastThrow = type -> flagValue(source, appId, version, type, PermanentFlags.JVM_OMIT_STACK_TRACE_IN_FAST_THROW);
this.resourceLimitDisk = flagValue(source, appId, version, PermanentFlags.RESOURCE_LIMIT_DISK);
this.resourceLimitMemory = flagValue(source, appId, version, PermanentFlags.RESOURCE_LIMIT_MEMORY);
this.minNodeRatioPerGroup = flagValue(source, appId, version, Flags.MIN_NODE_RATIO_PER_GROUP);
@@ -250,6 +250,7 @@ public class ModelContextImpl implements ModelContext {
this.heapPercentage = flagValue(source, appId, version, PermanentFlags.HEAP_SIZE_PERCENTAGE);
this.enableGlobalPhase = flagValue(source, appId, version, Flags.ENABLE_GLOBAL_PHASE);
this.summaryDecodePolicy = flagValue(source, appId, version, Flags.SUMMARY_DECODE_POLICY);
+ this.allowMoreThanOneContentGroupDown = clusterId -> flagValue(source, appId, version, clusterId, Flags.ALLOW_MORE_THAN_ONE_CONTENT_GROUP_DOWN);
}
@Override public int heapSizePercentage() { return heapPercentage; }
@@ -270,7 +271,7 @@ public class ModelContextImpl implements ModelContext {
@Override public List<String> allowedAthenzProxyIdentities() { return allowedAthenzProxyIdentities; }
@Override public int maxActivationInhibitedOutOfSyncGroups() { return maxActivationInhibitedOutOfSyncGroups; }
@Override public String jvmOmitStackTraceInFastThrowOption(ClusterSpec.Type type) {
- return translateJvmOmitStackTraceInFastThrowIntToString(jvmOmitStackTraceInFastThrow, type);
+ return translateJvmOmitStackTraceInFastThrowToString(jvmOmitStackTraceInFastThrow, type);
}
@Override public double resourceLimitDisk() { return resourceLimitDisk; }
@Override public double resourceLimitMemory() { return resourceLimitMemory; }
@@ -304,6 +305,7 @@ public class ModelContextImpl implements ModelContext {
}
@Override public boolean useRestrictedDataPlaneBindings() { return useRestrictedDataPlaneBindings; }
@Override public boolean enableGlobalPhase() { return enableGlobalPhase; }
+ @Override public boolean allowMoreThanOneContentGroupDown(ClusterSpec.Id id) { return allowMoreThanOneContentGroupDown.test(id); }
private static <V> V flagValue(FlagSource source, ApplicationId appId, Version vespaVersion, UnboundFlag<? extends V, ?, ?> flag) {
return flag.bindTo(source)
@@ -331,17 +333,21 @@ public class ModelContextImpl implements ModelContext {
.boxedValue();
}
- static int flagValueAsInt(FlagSource source,
- ApplicationId appId,
- Version version,
- ClusterSpec.Type clusterType,
- UnboundFlag<? extends Boolean, ?, ?> flag) {
- return flagValue(source, appId, version, clusterType, flag) ? 1 : 0;
+ private static <V> V flagValue(FlagSource source,
+ ApplicationId appId,
+ Version vespaVersion,
+ ClusterSpec.Id clusterId,
+ UnboundFlag<? extends V, ?, ?> flag) {
+ return flag.bindTo(source)
+ .with(FetchVector.Dimension.APPLICATION_ID, appId.serializedForm())
+ .with(FetchVector.Dimension.CLUSTER_ID, clusterId.value())
+ .with(FetchVector.Dimension.VESPA_VERSION, vespaVersion.toFullString())
+ .boxedValue();
}
- private String translateJvmOmitStackTraceInFastThrowIntToString(ToIntFunction<ClusterSpec.Type> function,
- ClusterSpec.Type clusterType) {
- return function.applyAsInt(clusterType) == 1 ? "" : "-XX:-OmitStackTraceInFastThrow";
+ private String translateJvmOmitStackTraceInFastThrowToString(Predicate<ClusterSpec.Type> function,
+ ClusterSpec.Type clusterType) {
+ return function.test(clusterType) ? "" : "-XX:-OmitStackTraceInFastThrow";
}
}
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index 36531fbf5e1..84411b31274 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -209,8 +209,11 @@
],
"methods" : [
"public void <init>(byte[])",
+ "public byte[] value()",
"public int compareTo(com.yahoo.prelude.hitfield.RawBase64)",
"public java.lang.String toString()",
+ "public boolean equals(java.lang.Object)",
+ "public int hashCode()",
"public bridge synthetic int compareTo(java.lang.Object)"
],
"fields" : [ ]
diff --git a/container-search/src/main/java/com/yahoo/prelude/hitfield/RawBase64.java b/container-search/src/main/java/com/yahoo/prelude/hitfield/RawBase64.java
index 2071e43f54c..485e2c9a8c3 100644
--- a/container-search/src/main/java/com/yahoo/prelude/hitfield/RawBase64.java
+++ b/container-search/src/main/java/com/yahoo/prelude/hitfield/RawBase64.java
@@ -3,16 +3,22 @@ package com.yahoo.prelude.hitfield;
import java.util.Arrays;
import java.util.Base64;
+import java.util.Objects;
/**
+ * Wraps a byte [] and renders it as base64 encoded string
* @author baldersheim
*/
public class RawBase64 implements Comparable<RawBase64> {
+ private final static Base64.Encoder encoder = Base64.getEncoder().withoutPadding();
private final byte[] content;
public RawBase64(byte[] content) {
+ Objects.requireNonNull(content);
this.content = content;
}
+ public byte [] value() { return content; }
+
@Override
public int compareTo(RawBase64 rhs) {
return Arrays.compareUnsigned(content, rhs.content);
@@ -20,6 +26,19 @@ public class RawBase64 implements Comparable<RawBase64> {
@Override
public String toString() {
- return Base64.getEncoder().encodeToString(content);
+ return encoder.encodeToString(content);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ RawBase64 rawBase64 = (RawBase64) o;
+ return Arrays.equals(content, rawBase64.content);
+ }
+
+ @Override
+ public int hashCode() {
+ return Arrays.hashCode(content);
}
}
diff --git a/container-search/src/main/java/com/yahoo/prelude/query/MultiRangeItem.java b/container-search/src/main/java/com/yahoo/prelude/query/MultiRangeItem.java
index 7ba7a13936f..3dac4cb92c0 100644
--- a/container-search/src/main/java/com/yahoo/prelude/query/MultiRangeItem.java
+++ b/container-search/src/main/java/com/yahoo/prelude/query/MultiRangeItem.java
@@ -271,7 +271,7 @@ public class MultiRangeItem<Type extends Number> extends MultiTermItem {
if (endInclusive) metadata |= 0b00000100;
encoder = type.encoderFor(sortedRanges());
- metadata |= encoder.id << 3;
+ metadata |= (byte)(encoder.id << 3);
buffer.put(metadata);
putString(startIndex, buffer);
diff --git a/container-search/src/main/java/com/yahoo/prelude/query/MultiTermItem.java b/container-search/src/main/java/com/yahoo/prelude/query/MultiTermItem.java
index a7ca62d153c..03a661499e0 100644
--- a/container-search/src/main/java/com/yahoo/prelude/query/MultiTermItem.java
+++ b/container-search/src/main/java/com/yahoo/prelude/query/MultiTermItem.java
@@ -67,8 +67,8 @@ abstract class MultiTermItem extends SimpleTaggableItem {
super.encodeThis(buffer);
byte metadata = 0;
- metadata |= (operatorType().code << 5) & 0b11100000;
- metadata |= ( termType().code ) & 0b00011111;
+ metadata |= (byte)((byte)(operatorType().code << 5) & (byte)0b11100000);
+ metadata |= (byte)(termType().code & (byte)0b00011111);
buffer.put(metadata);
buffer.putInt(terms());
encodeBlueprint(buffer);
diff --git a/container-search/src/main/java/com/yahoo/search/grouping/result/BucketGroupId.java b/container-search/src/main/java/com/yahoo/search/grouping/result/BucketGroupId.java
index 05efc134465..0bea390ad63 100644
--- a/container-search/src/main/java/com/yahoo/search/grouping/result/BucketGroupId.java
+++ b/container-search/src/main/java/com/yahoo/search/grouping/result/BucketGroupId.java
@@ -1,8 +1,6 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.grouping.result;
-import static com.yahoo.text.Lowercase.toLowerCase;
-
/**
* This abstract class is used in {@link Group} instances where the identifying expression evaluated to a {@link
* com.yahoo.search.grouping.request.BucketValue}. The range is inclusive-from and exclusive-to.
diff --git a/container-search/src/main/java/com/yahoo/search/grouping/result/HitRenderer.java b/container-search/src/main/java/com/yahoo/search/grouping/result/HitRenderer.java
index 343fea82b6e..91c46960ab0 100644
--- a/container-search/src/main/java/com/yahoo/search/grouping/result/HitRenderer.java
+++ b/container-search/src/main/java/com/yahoo/search/grouping/result/HitRenderer.java
@@ -7,7 +7,6 @@ import com.yahoo.text.Utf8String;
import com.yahoo.text.XMLWriter;
import java.io.IOException;
-import java.util.Arrays;
import java.util.Map;
/**
@@ -63,28 +62,15 @@ public abstract class HitRenderer {
private static void renderGroupId(GroupId id, XMLWriter writer) {
writer.openTag(TAG_GROUP_ID).attribute(ATR_TYPE, id.getTypeName());
- if (id instanceof ValueGroupId) {
- writer.content(getIdValue((ValueGroupId)id), false);
- } else if (id instanceof BucketGroupId) {
- BucketGroupId bucketId = (BucketGroupId)id;
- writer.openTag(TAG_BUCKET_FROM).content(getBucketFrom(bucketId), false).closeTag();
- writer.openTag(TAG_BUCKET_TO).content(getBucketTo(bucketId), false).closeTag();
+ if (id instanceof ValueGroupId<?> valueGroupId) {
+ writer.content(valueGroupId.getValue(), false);
+ } else if (id instanceof BucketGroupId bucketId) {
+ writer.openTag(TAG_BUCKET_FROM).content(bucketId.getFrom(), false).closeTag();
+ writer.openTag(TAG_BUCKET_TO).content(bucketId.getTo(), false).closeTag();
}
writer.closeTag();
}
- private static Object getIdValue(ValueGroupId id) {
- return id instanceof RawId ? Arrays.toString(((RawId)id).getValue()) : id.getValue();
- }
-
- private static Object getBucketFrom(BucketGroupId id) {
- return id instanceof RawBucketId ? Arrays.toString(((RawBucketId)id).getFrom()) : id.getFrom();
- }
-
- private static Object getBucketTo(BucketGroupId id) {
- return id instanceof RawBucketId ? Arrays.toString(((RawBucketId)id).getTo()) : id.getTo();
- }
-
private static void renderContinuations(Map<String, Continuation> continuations, XMLWriter writer) {
for (Map.Entry<String, Continuation> entry : continuations.entrySet()) {
renderContinuation(entry.getKey(), entry.getValue(), writer);
diff --git a/container-search/src/main/java/com/yahoo/search/grouping/result/RawBucketId.java b/container-search/src/main/java/com/yahoo/search/grouping/result/RawBucketId.java
index 9576f548f4a..9b5ad6660b0 100644
--- a/container-search/src/main/java/com/yahoo/search/grouping/result/RawBucketId.java
+++ b/container-search/src/main/java/com/yahoo/search/grouping/result/RawBucketId.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.grouping.result;
-import java.util.Arrays;
+import com.yahoo.prelude.hitfield.RawBase64;
/**
* This class is used in {@link Group} instances where the identifying
@@ -9,7 +9,7 @@ import java.util.Arrays;
*
* @author Ulf Lilleengen
*/
-public class RawBucketId extends BucketGroupId<byte[]> {
+public class RawBucketId extends BucketGroupId<RawBase64> {
/**
* Constructs a new instance of this class.
@@ -18,6 +18,6 @@ public class RawBucketId extends BucketGroupId<byte[]> {
* @param to The identifying exclusive-to raw buffer.
*/
public RawBucketId(byte[] from, byte[] to) {
- super("raw_bucket", from, Arrays.toString(from), to, Arrays.toString(to));
+ super("raw_bucket", new RawBase64(from), new RawBase64(to));
}
}
diff --git a/container-search/src/main/java/com/yahoo/search/grouping/result/RawId.java b/container-search/src/main/java/com/yahoo/search/grouping/result/RawId.java
index de711d0c218..fd0d38c37fd 100644
--- a/container-search/src/main/java/com/yahoo/search/grouping/result/RawId.java
+++ b/container-search/src/main/java/com/yahoo/search/grouping/result/RawId.java
@@ -1,14 +1,14 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.grouping.result;
-import java.util.Arrays;
+import com.yahoo.prelude.hitfield.RawBase64;
/**
* This class is used in {@link Group} instances where the identifying expression evaluated to a {@link Byte} array.
*
* @author Simon Thoresen Hult
*/
-public class RawId extends ValueGroupId<byte[]> {
+public class RawId extends ValueGroupId<RawBase64> {
/**
* Constructs a new instance of this class.
@@ -16,6 +16,6 @@ public class RawId extends ValueGroupId<byte[]> {
* @param value The identifying byte array.
*/
public RawId(byte[] value) {
- super("raw", value, Arrays.toString(value));
+ super("raw", new RawBase64(value));
}
}
diff --git a/container-search/src/main/java/com/yahoo/search/grouping/vespa/ResultBuilder.java b/container-search/src/main/java/com/yahoo/search/grouping/vespa/ResultBuilder.java
index 7f006b098cd..e746706f9c5 100644
--- a/container-search/src/main/java/com/yahoo/search/grouping/vespa/ResultBuilder.java
+++ b/container-search/src/main/java/com/yahoo/search/grouping/vespa/ResultBuilder.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.grouping.vespa;
+import com.yahoo.prelude.hitfield.RawBase64;
import com.yahoo.search.grouping.Continuation;
import com.yahoo.search.grouping.GroupingRequest;
import com.yahoo.search.grouping.result.BoolId;
@@ -28,6 +29,7 @@ import com.yahoo.searchlib.aggregation.Hit;
import com.yahoo.searchlib.aggregation.HitsAggregationResult;
import com.yahoo.searchlib.aggregation.MaxAggregationResult;
import com.yahoo.searchlib.aggregation.MinAggregationResult;
+import com.yahoo.searchlib.aggregation.RawData;
import com.yahoo.searchlib.aggregation.StandardDeviationAggregationResult;
import com.yahoo.searchlib.aggregation.SumAggregationResult;
import com.yahoo.searchlib.aggregation.XorAggregationResult;
@@ -169,7 +171,7 @@ class ResultBuilder {
} else {
String label = transform.getLabel(result.getTag());
if (label != null) {
- group.setField(label, newResult(result, tag));
+ group.setField(label, convertResult(newResult(result, tag)));
}
}
}
@@ -228,24 +230,27 @@ class ResultBuilder {
return new RawId(res.getRaw());
} else if (res instanceof StringResultNode) {
return new StringId(res.getString());
- } else if (res instanceof FloatBucketResultNode) {
- FloatBucketResultNode bucketId = (FloatBucketResultNode)res;
+ } else if (res instanceof FloatBucketResultNode bucketId) {
return new DoubleBucketId(bucketId.getFrom(), bucketId.getTo());
- } else if (res instanceof IntegerBucketResultNode) {
- IntegerBucketResultNode bucketId = (IntegerBucketResultNode)res;
+ } else if (res instanceof IntegerBucketResultNode bucketId) {
return new LongBucketId(bucketId.getFrom(), bucketId.getTo());
- } else if (res instanceof StringBucketResultNode) {
- StringBucketResultNode bucketId = (StringBucketResultNode)res;
+ } else if (res instanceof StringBucketResultNode bucketId) {
return new StringBucketId(bucketId.getFrom(), bucketId.getTo());
- } else if (res instanceof RawBucketResultNode) {
- RawBucketResultNode bucketId = (RawBucketResultNode)res;
+ } else if (res instanceof RawBucketResultNode bucketId) {
return new RawBucketId(bucketId.getFrom(), bucketId.getTo());
} else {
throw new UnsupportedOperationException(res.getClass().getName());
}
}
- Object newResult(ExpressionNode execResult, int tag) {
+ private Object convertResult(Object value) {
+ if (value instanceof RawData raw) {
+ return new RawBase64(raw.getData());
+ }
+ return value;
+ }
+
+ private Object newResult(ExpressionNode execResult, int tag) {
if (execResult instanceof AverageAggregationResult) {
return ((AverageAggregationResult)execResult).getAverage().getNumber();
} else if (execResult instanceof CountAggregationResult) {
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/compiled/Binding.java b/container-search/src/main/java/com/yahoo/search/query/profile/compiled/Binding.java
index 99c3477274d..01bbef13129 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/compiled/Binding.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/compiled/Binding.java
@@ -48,7 +48,7 @@ public class Binding implements Comparable<Binding> {
for (int i = 0; i <= maxDimensions; i++) {
String value = i < dimensionBinding.getDimensions().size() ? dimensionBinding.getValues().get(i) : null;
if (value == null)
- generality += Math.pow(2, maxDimensions - i-1);
+ generality += (int)Math.pow(2, maxDimensions - i - 1);
else
context.put(dimensionBinding.getDimensions().get(i), value);
}
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java
index b36c8788877..90f4e6ae65c 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java
@@ -32,8 +32,6 @@ import com.yahoo.search.grouping.result.AbstractList;
import com.yahoo.search.grouping.result.BucketGroupId;
import com.yahoo.search.grouping.result.Group;
import com.yahoo.search.grouping.result.GroupId;
-import com.yahoo.search.grouping.result.RawBucketId;
-import com.yahoo.search.grouping.result.RawId;
import com.yahoo.search.grouping.result.RootGroup;
import com.yahoo.search.grouping.result.ValueGroupId;
import com.yahoo.search.result.Coverage;
@@ -57,7 +55,6 @@ import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.util.ArrayDeque;
-import java.util.Arrays;
import java.util.Deque;
import java.util.Map;
import java.util.Optional;
@@ -420,31 +417,19 @@ public class JsonRenderer extends AsynchronousSectionedRenderer<Result> {
}
protected void renderGroupMetadata(GroupId id) throws IOException {
- if (!(id instanceof ValueGroupId || id instanceof BucketGroupId)) return;
+ if (!(id instanceof ValueGroupId<?> || id instanceof BucketGroupId)) return;
- if (id instanceof ValueGroupId valueId) {
- generator.writeStringField(GROUPING_VALUE, getIdValue(valueId));
+ if (id instanceof ValueGroupId<?> valueId) {
+ generator.writeStringField(GROUPING_VALUE, valueId.getValue().toString());
} else {
BucketGroupId<?> bucketId = (BucketGroupId<?>) id;
generator.writeObjectFieldStart(BUCKET_LIMITS);
- generator.writeStringField(BUCKET_FROM, getBucketFrom(bucketId));
- generator.writeStringField(BUCKET_TO, getBucketTo(bucketId));
+ generator.writeStringField(BUCKET_FROM, bucketId.getFrom().toString());
+ generator.writeStringField(BUCKET_TO, bucketId.getTo().toString());
generator.writeEndObject();
}
}
- private static String getIdValue(ValueGroupId<?> id) {
- return (id instanceof RawId ? Arrays.toString(((RawId) id).getValue()) : id.getValue()).toString();
- }
-
- private static String getBucketFrom(BucketGroupId<?> id) {
- return (id instanceof RawBucketId ? Arrays.toString(((RawBucketId) id).getFrom()) : id.getFrom()).toString();
- }
-
- private static String getBucketTo(BucketGroupId<?> id) {
- return (id instanceof RawBucketId ? Arrays.toString(((RawBucketId) id).getTo()) : id.getTo()).toString();
- }
-
protected void renderTotalHitCount(Hit hit) throws IOException {
if ( ! (getRecursionLevel() == 1 && hit instanceof HitGroup)) return;
diff --git a/container-search/src/test/java/com/yahoo/search/grouping/result/GroupIdTestCase.java b/container-search/src/test/java/com/yahoo/search/grouping/result/GroupIdTestCase.java
index 7b2f0d52742..77ed858b14b 100644
--- a/container-search/src/test/java/com/yahoo/search/grouping/result/GroupIdTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/grouping/result/GroupIdTestCase.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.grouping.result;
+import com.yahoo.prelude.hitfield.RawBase64;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
@@ -25,10 +26,10 @@ public class GroupIdTestCase {
assertEquals(9L, rangeId.getTo());
valueId = new RawId(new byte[]{6, 9});
- assertArrayEquals(new byte[]{6, 9}, (byte[]) valueId.getValue());
+ assertEquals(new RawBase64(new byte[]{6, 9}), valueId.getValue());
rangeId = new RawBucketId(new byte[]{6, 9}, new byte[]{9, 6});
- assertArrayEquals(new byte[]{6, 9}, (byte[]) rangeId.getFrom());
- assertArrayEquals(new byte[]{9, 6}, (byte[]) rangeId.getTo());
+ assertEquals(new RawBase64(new byte[]{6, 9}), rangeId.getFrom());
+ assertEquals(new RawBase64(new byte[]{9, 6}), rangeId.getTo());
valueId = new StringId("69");
assertEquals("69", valueId.getValue());
@@ -47,8 +48,8 @@ public class GroupIdTestCase {
assertEquals("group:long:69", new LongId(69L).toString());
assertEquals("group:long_bucket:6:9", new LongBucketId(6L, 9L).toString());
assertEquals("group:null", new NullId().toString());
- assertEquals("group:raw:[6, 9]", new RawId(new byte[]{6, 9}).toString());
- assertEquals("group:raw_bucket:[6, 9]:[9, 6]", new RawBucketId(new byte[]{6, 9}, new byte[]{9, 6}).toString());
+ assertEquals("group:raw:Bgk", new RawId(new byte[]{6, 9}).toString());
+ assertEquals("group:raw_bucket:Bgk:CQY", new RawBucketId(new byte[]{6, 9}, new byte[]{9, 6}).toString());
assertTrue(new RootId(0).toString().startsWith("group:root:"));
assertEquals("group:string:69", new StringId("69").toString());
assertEquals("group:string_bucket:6:9", new StringBucketId("6", "9").toString());
diff --git a/container-search/src/test/java/com/yahoo/search/grouping/result/HitRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/grouping/result/HitRendererTestCase.java
index 8e98f49df48..69bd848ebcd 100644
--- a/container-search/src/test/java/com/yahoo/search/grouping/result/HitRendererTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/grouping/result/HitRendererTestCase.java
@@ -57,7 +57,7 @@ public class HitRendererTestCase {
"</group>\n");
assertRender(newGroup(new RawId(Utf8.toBytes("foo"))),
"<group relevance=\"1.0\">\n" +
- "<id type=\"raw\">[102, 111, 111]</id>\n" +
+ "<id type=\"raw\">Zm9v</id>\n" +
"</group>\n");
assertRender(newGroup(new StringId("foo")),
"<group relevance=\"1.0\">\n" +
@@ -85,7 +85,7 @@ public class HitRendererTestCase {
"</group>\n");
assertRender(newGroup(new RawBucketId(Utf8.toBytes("bar"), Utf8.toBytes("baz"))),
"<group relevance=\"1.0\">\n" +
- "<id type=\"raw_bucket\">\n<from>[98, 97, 114]</from>\n<to>[98, 97, 122]</to>\n</id>\n" +
+ "<id type=\"raw_bucket\">\n<from>YmFy</from>\n<to>YmF6</to>\n</id>\n" +
"</group>\n");
}
diff --git a/container-search/src/test/java/com/yahoo/search/grouping/vespa/ResultBuilderTestCase.java b/container-search/src/test/java/com/yahoo/search/grouping/vespa/ResultBuilderTestCase.java
index 019a022b7e6..b0b48bb8731 100644
--- a/container-search/src/test/java/com/yahoo/search/grouping/vespa/ResultBuilderTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/grouping/vespa/ResultBuilderTestCase.java
@@ -10,12 +10,39 @@ import com.yahoo.search.grouping.result.GroupList;
import com.yahoo.search.grouping.result.HitList;
import com.yahoo.search.result.HitGroup;
import com.yahoo.search.result.Relevance;
-import com.yahoo.searchlib.aggregation.*;
+import com.yahoo.searchlib.aggregation.AggregationResult;
+import com.yahoo.searchlib.aggregation.AverageAggregationResult;
+import com.yahoo.searchlib.aggregation.CountAggregationResult;
+import com.yahoo.searchlib.aggregation.ExpressionCountAggregationResult;
+import com.yahoo.searchlib.aggregation.FS4Hit;
+import com.yahoo.searchlib.aggregation.Group;
+import com.yahoo.searchlib.aggregation.Grouping;
+import com.yahoo.searchlib.aggregation.HitsAggregationResult;
+import com.yahoo.searchlib.aggregation.MaxAggregationResult;
+import com.yahoo.searchlib.aggregation.MinAggregationResult;
+import com.yahoo.searchlib.aggregation.SumAggregationResult;
+import com.yahoo.searchlib.aggregation.XorAggregationResult;
import com.yahoo.searchlib.aggregation.hll.SparseSketch;
-import com.yahoo.searchlib.expression.*;
+import com.yahoo.searchlib.expression.FloatBucketResultNode;
+import com.yahoo.searchlib.expression.FloatResultNode;
+import com.yahoo.searchlib.expression.IntegerBucketResultNode;
+import com.yahoo.searchlib.expression.IntegerResultNode;
+import com.yahoo.searchlib.expression.NullResultNode;
+import com.yahoo.searchlib.expression.RawBucketResultNode;
+import com.yahoo.searchlib.expression.RawResultNode;
+import com.yahoo.searchlib.expression.ResultNode;
+import com.yahoo.searchlib.expression.StringBucketResultNode;
+import com.yahoo.searchlib.expression.StringResultNode;
import org.junit.jupiter.api.Test;
-import java.util.*;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
@@ -33,19 +60,19 @@ public class ResultBuilderTestCase {
assertGroupId("group:6.9", new FloatResultNode(6.9));
assertGroupId("group:69", new IntegerResultNode(69));
assertGroupId("group:null", new NullResultNode());
- assertGroupId("group:[6, 9]", new RawResultNode(new byte[]{6, 9}));
+ assertGroupId("group:Bgk", new RawResultNode(new byte[]{6, 9}));
assertGroupId("group:a", new StringResultNode("a"));
assertGroupId("group:6.9:9.6", new FloatBucketResultNode(6.9, 9.6));
assertGroupId("group:6:9", new IntegerBucketResultNode(6, 9));
assertGroupId("group:a:b", new StringBucketResultNode("a", "b"));
- assertGroupId("group:[6, 9]:[9, 6]", new RawBucketResultNode(new RawResultNode(new byte[]{6, 9}),
+ assertGroupId("group:Bgk:CQY", new RawBucketResultNode(new RawResultNode(new byte[]{6, 9}),
new RawResultNode(new byte[]{9, 6})));
}
@Test
void requireThatUnknownGroupIdThrows() {
assertBuildFail("all(group(a) each(output(count())))",
- Arrays.asList(newGrouping(new Group().setTag(2).setId(new MyResultNode()))),
+ List.of(newGrouping(new Group().setTag(2).setId(new MyResultNode()))),
"com.yahoo.search.grouping.vespa.ResultBuilderTestCase$MyResultNode");
}
@@ -61,9 +88,17 @@ public class ResultBuilderTestCase {
}
@Test
+ void requireThatAllBasicResultsCanBeConverted() {
+ assertResult("69", new MinAggregationResult(new IntegerResultNode(69)));
+ assertResult("69.3", new MinAggregationResult(new FloatResultNode(69.3)));
+ assertResult("69.6", new MinAggregationResult(new StringResultNode("69.6")));
+ assertResult("Bgk", new MinAggregationResult(new RawResultNode(new byte[]{6,9})));
+ }
+
+ @Test
void requireThatUnknownExpressionNodeThrows() {
assertBuildFail("all(group(a) each(output(count())))",
- Arrays.asList(newGrouping(newGroup(2, 2, new MyAggregationResult().setTag(3)))),
+ List.of(newGrouping(newGroup(2, 2, new MyAggregationResult().setTag(3)))),
"com.yahoo.search.grouping.vespa.ResultBuilderTestCase$MyAggregationResult");
}
@@ -127,10 +162,10 @@ public class ResultBuilderTestCase {
@Test
void requireThatParallelResultsAreTransformed() {
assertBuild("all(group(foo) each(output(count())) as(bar) each(output(count())) as(baz))",
- Arrays.asList(new Grouping().setRoot(newGroup(1, 0)),
+ List.of(new Grouping().setRoot(newGroup(1, 0)),
new Grouping().setRoot(newGroup(1, 0))));
assertBuildFail("all(group(foo) each(output(count())) as(bar) each(output(count())) as(baz))",
- Arrays.asList(new Grouping().setRoot(newGroup(2)),
+ List.of(new Grouping().setRoot(newGroup(2)),
new Grouping().setRoot(newGroup(3))),
"Expected 1 group, got 2.");
}
@@ -138,15 +173,15 @@ public class ResultBuilderTestCase {
@Test
void requireThatTagsAreHandledCorrectly() {
assertBuild("all(group(a) each(output(count())))",
- Arrays.asList(newGrouping(
+ List.of(newGrouping(
newGroup(7, new CountAggregationResult(0)))));
}
@Test
void requireThatEmptyBranchesArePruned() {
- assertBuildFail("all()", Collections.<Grouping>emptyList(), "Expected 1 group, got 0.");
- assertBuildFail("all(group(a))", Collections.<Grouping>emptyList(), "Expected 1 group, got 0.");
- assertBuildFail("all(group(a) each())", Collections.<Grouping>emptyList(), "Expected 1 group, got 0.");
+ assertBuildFail("all()", List.of(), "Expected 1 group, got 0.");
+ assertBuildFail("all(group(a))", List.of(), "Expected 1 group, got 0.");
+ assertBuildFail("all(group(a) each())", List.of(), "Expected 1 group, got 0.");
Grouping grouping = newGrouping(newGroup(2, new CountAggregationResult(69).setTag(3)));
String expectedOutput = "RootGroup{id=group:root}[GroupList{label=a}[Group{id=group:2, count()=69}[]]]";
@@ -189,14 +224,14 @@ public class ResultBuilderTestCase {
"HitList{label=bar}[Hit{id=hit:1}, Hit{id=hit:2}]]]]");
assertLayout("all(group(foo) each(each(output(summary())) as(bar)" +
" each(output(summary())) as(baz)))",
- Arrays.asList(newGrouping(newGroup(2, newHitList(3, 2))),
+ List.of(newGrouping(newGroup(2, newHitList(3, 2))),
newGrouping(newGroup(2, newHitList(4, 2)))),
"RootGroup{id=group:root}[GroupList{label=foo}[Group{id=group:2}[" +
"HitList{label=bar}[Hit{id=hit:1}, Hit{id=hit:2}], " +
"HitList{label=baz}[Hit{id=hit:1}, Hit{id=hit:2}]]]]");
assertLayout("all(group(foo) each(each(output(summary())))" +
" each(each(output(summary()))) as(bar))",
- Arrays.asList(newGrouping(newGroup(2, newHitList(3, 2))),
+ List.of(newGrouping(newGroup(2, newHitList(3, 2))),
newGrouping(newGroup(4, newHitList(5, 2)))),
"RootGroup{id=group:root}[" +
"GroupList{label=foo}[Group{id=group:2}[HitList{label=hits}[Hit{id=hit:1}, Hit{id=hit:2}]]], " +
@@ -273,18 +308,18 @@ public class ResultBuilderTestCase {
assertResultCont("all(group(a) max(2) each(output(count())) as(foo)" +
" each(output(count())) as(bar))",
- Arrays.asList(newGrouping(newGroup(2, 1, new CountAggregationResult(1))),
+ List.of(newGrouping(newGroup(2, 1, new CountAggregationResult(1))),
newGrouping(newGroup(4, 2, new CountAggregationResult(4)))),
"[]");
assertResultCont("all(group(a) max(2) each(output(count())) as(foo)" +
" each(output(count())) as(bar))",
- Arrays.asList(newGrouping(newGroup(2, 1, new CountAggregationResult(1))),
+ List.of(newGrouping(newGroup(2, 1, new CountAggregationResult(1))),
newGrouping(newGroup(4, 2, new CountAggregationResult(4)))),
newOffset(newResultId(0), 2, 1),
"[0=1]");
assertResultCont("all(group(a) max(2) each(output(count())) as(foo)" +
" each(output(count())) as(bar))",
- Arrays.asList(newGrouping(newGroup(2, 1, new CountAggregationResult(1))),
+ List.of(newGrouping(newGroup(2, 1, new CountAggregationResult(1))),
newGrouping(newGroup(4, 2, new CountAggregationResult(4)))),
newComposite(newOffset(newResultId(0), 2, 2),
newOffset(newResultId(1), 4, 1)),
@@ -299,18 +334,18 @@ public class ResultBuilderTestCase {
assertResultCont("all(group(a) each(max(2) each(output(summary()))) as(foo)" +
" each(max(2) each(output(summary()))) as(bar))",
- Arrays.asList(newGrouping(newGroup(2, newHitList(3, 4))),
+ List.of(newGrouping(newGroup(2, newHitList(3, 4))),
newGrouping(newGroup(4, newHitList(5, 4)))),
"[]");
assertResultCont("all(group(a) each(max(2) each(output(summary()))) as(foo)" +
" each(max(2) each(output(summary()))) as(bar))",
- Arrays.asList(newGrouping(newGroup(2, newHitList(3, 4))),
+ List.of(newGrouping(newGroup(2, newHitList(3, 4))),
newGrouping(newGroup(4, newHitList(5, 4)))),
newOffset(newResultId(0, 0, 0), 3, 1),
"[0.0.0=1]");
assertResultCont("all(group(a) each(max(2) each(output(summary()))) as(foo)" +
" each(max(2) each(output(summary()))) as(bar))",
- Arrays.asList(newGrouping(newGroup(2, newHitList(3, 4))),
+ List.of(newGrouping(newGroup(2, newHitList(3, 4))),
newGrouping(newGroup(4, newHitList(5, 4)))),
newComposite(newOffset(newResultId(0, 0, 0), 3, 2),
newOffset(newResultId(1, 0, 0), 5, 1)),
@@ -404,7 +439,7 @@ public class ResultBuilderTestCase {
void requireThatGroupListContinuationsCanBeSetInSiblingGroupLists() {
String request = "all(group(a) max(2) each(output(count())) as(foo)" +
" each(output(count())) as(bar))";
- List<Grouping> result = Arrays.asList(newGrouping(newGroup(2, 1, new CountAggregationResult(1)),
+ List<Grouping> result = List.of(newGrouping(newGroup(2, 1, new CountAggregationResult(1)),
newGroup(2, 2, new CountAggregationResult(2)),
newGroup(2, 3, new CountAggregationResult(3)),
newGroup(2, 4, new CountAggregationResult(4))),
@@ -646,7 +681,7 @@ public class ResultBuilderTestCase {
void requireThatHitListContinuationsCanBeSetInSiblingHitLists() {
String request = "all(group(a) each(max(2) each(output(summary()))) as(foo)" +
" each(max(2) each(output(summary()))) as(bar))";
- List<Grouping> result = Arrays.asList(newGrouping(newGroup(2, newHitList(3, 4))),
+ List<Grouping> result = List.of(newGrouping(newGroup(2, newHitList(3, 4))),
newGrouping(newGroup(4, newHitList(5, 4))));
assertContinuation(request, result, newComposite(newOffset(newResultId(0, 0, 0), 3, 0),
newOffset(newResultId(1, 0, 0), 5, 5)),
@@ -839,7 +874,7 @@ public class ResultBuilderTestCase {
}
private static void assertResultCont(String request, Grouping result, Continuation cont, String expected) {
- assertOutput(request, Arrays.asList(result), cont, new ResultContWriter(), expected);
+ assertOutput(request, List.of(result), cont, new ResultContWriter(), expected);
}
private static void assertResultCont(String request, List<Grouping> result, String expected) {
@@ -851,11 +886,11 @@ public class ResultBuilderTestCase {
}
private static void assertContinuation(String request, Grouping result, String expected) {
- assertOutput(request, Arrays.asList(result), null, new ContinuationWriter(), expected);
+ assertOutput(request, List.of(result), null, new ContinuationWriter(), expected);
}
private static void assertContinuation(String request, Grouping result, Continuation cont, String expected) {
- assertOutput(request, Arrays.asList(result), cont, new ContinuationWriter(), expected);
+ assertOutput(request, List.of(result), cont, new ContinuationWriter(), expected);
}
private static void assertContinuation(String request, List<Grouping> result, Continuation cont, String expected) {
@@ -863,7 +898,7 @@ public class ResultBuilderTestCase {
}
private static void assertLayout(String request, Grouping result, String expected) {
- assertOutput(request, Arrays.asList(result), null, new LayoutWriter(), expected);
+ assertOutput(request, List.of(result), null, new LayoutWriter(), expected);
}
private static void assertLayout(String request, List<Grouping> result, String expected) {
@@ -953,8 +988,7 @@ public class ResultBuilderTestCase {
}
String toString(Continuation cnt) {
- if (cnt instanceof OffsetContinuation) {
- OffsetContinuation off = (OffsetContinuation)cnt;
+ if (cnt instanceof OffsetContinuation off) {
String id = off.getResultId().toString().replace(", ", ".");
return id.substring(5, id.length() - 1) + "=" + off.getOffset();
} else if (cnt instanceof CompositeContinuation) {
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/ApplicationVersion.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/ApplicationVersion.java
index 04604ae7007..eb2005bf268 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/ApplicationVersion.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/ApplicationVersion.java
@@ -1,11 +1,9 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.controller.api.integration.deployment;
-import ai.vespa.validation.Validation;
import com.yahoo.component.Version;
import java.time.Instant;
-import java.util.Objects;
import java.util.Optional;
import java.util.OptionalLong;
@@ -33,6 +31,7 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> {
private final Optional<String> sourceUrl;
private final Optional<String> commit;
private final Optional<String> bundleHash;
+ private final Optional<Instant> obsoleteAt;
private final boolean hasPackage;
private final boolean shouldSkip;
private final Optional<String> description;
@@ -41,7 +40,7 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> {
public ApplicationVersion(RevisionId id, Optional<SourceRevision> source, Optional<String> authorEmail,
Optional<Version> compileVersion, Optional<Integer> allowedMajor, Optional<Instant> buildTime,
Optional<String> sourceUrl, Optional<String> commit, Optional<String> bundleHash,
- boolean hasPackage, boolean shouldSkip, Optional<String> description, int risk) {
+ Optional<Instant> obsoleteAt, boolean hasPackage, boolean shouldSkip, Optional<String> description, int risk) {
if (commit.isPresent() && commit.get().length() > 128)
throw new IllegalArgumentException("Commit may not be longer than 128 characters");
@@ -61,6 +60,7 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> {
this.sourceUrl = requireNonNull(sourceUrl, "sourceUrl cannot be null");
this.commit = requireNonNull(commit, "commit cannot be null");
this.bundleHash = bundleHash;
+ this.obsoleteAt = obsoleteAt;
this.hasPackage = hasPackage;
this.shouldSkip = shouldSkip;
this.description = description;
@@ -71,19 +71,9 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> {
return id;
}
- /** Create an application package version from a completed build, without an author email */
- public static ApplicationVersion from(RevisionId id, SourceRevision source) {
- return new ApplicationVersion(id, Optional.of(source), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), true, false, Optional.empty(), 0);
- }
-
- /** Creates a version from a completed build, an author email, and build metadata. */
- public static ApplicationVersion from(RevisionId id, SourceRevision source, String authorEmail, Version compileVersion, Instant buildTime) {
- return new ApplicationVersion(id, Optional.of(source), Optional.of(authorEmail), Optional.of(compileVersion), Optional.empty(), Optional.of(buildTime), Optional.empty(), Optional.empty(), Optional.empty(), true, false, Optional.empty(), 0);
- }
-
/** Creates a minimal version for a development build. */
public static ApplicationVersion forDevelopment(RevisionId id, Optional<Version> compileVersion, Optional<Integer> allowedMajor) {
- return new ApplicationVersion(id, Optional.empty(), Optional.empty(), compileVersion, allowedMajor, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), true, false, Optional.empty(), 0);
+ return new ApplicationVersion(id, Optional.empty(), Optional.empty(), compileVersion, allowedMajor, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), true, false, Optional.empty(), 0);
}
/** Creates a version from a completed build, an author email, and build metadata. */
@@ -91,7 +81,7 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> {
Optional<Version> compileVersion, Optional<Integer> allowedMajor, Optional<Instant> buildTime, Optional<String> sourceUrl,
Optional<String> commit, Optional<String> bundleHash, Optional<String> description, int risk) {
return new ApplicationVersion(id, source, authorEmail, compileVersion, allowedMajor, buildTime,
- sourceUrl, commit, bundleHash, true, false, description, risk);
+ sourceUrl, commit, bundleHash, Optional.empty(), true, false, description, risk);
}
/** Returns a unique identifier for this version or "unknown" if version is not known */
@@ -150,7 +140,17 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> {
/** Returns a copy of this without a package stored. */
public ApplicationVersion withoutPackage() {
- return new ApplicationVersion(id, source, authorEmail, compileVersion, allowedMajor, buildTime, sourceUrl, commit, bundleHash, false, shouldSkip, description, risk);
+ return new ApplicationVersion(id, source, authorEmail, compileVersion, allowedMajor, buildTime, sourceUrl, commit, bundleHash, obsoleteAt, false, shouldSkip, description, risk);
+ }
+
+ /** Returns a copy of this which is obsolete now. */
+ public ApplicationVersion obsoleteAt(Instant now) {
+ return new ApplicationVersion(id, source, authorEmail, compileVersion, allowedMajor, buildTime, sourceUrl, commit, bundleHash, Optional.of(now), hasPackage, shouldSkip, description, risk);
+ }
+
+ /** Returns the instant at which this became obsolete, i.e., no longer relevant for automated deployments. */
+ public Optional<Instant> obsoleteAt() {
+ return obsoleteAt;
}
/** Whether we still have the package for this revision. */
@@ -160,7 +160,7 @@ public class ApplicationVersion implements Comparable<ApplicationVersion> {
/** Returns a copy of this which will not be rolled out to production. */
public ApplicationVersion skipped() {
- return new ApplicationVersion(id, source, authorEmail, compileVersion, allowedMajor, buildTime, sourceUrl, commit, bundleHash, hasPackage, true, description, risk);
+ return new ApplicationVersion(id, source, authorEmail, compileVersion, allowedMajor, buildTime, sourceUrl, commit, bundleHash, obsoleteAt, hasPackage, true, description, risk);
}
/** Whether we still have the package for this revision. */
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Change.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Change.java
index 64cad599168..5ebb3d53529 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Change.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Change.java
@@ -2,7 +2,6 @@
package com.yahoo.vespa.hosted.controller.application;
import com.yahoo.component.Version;
-import com.yahoo.vespa.hosted.controller.api.integration.deployment.ApplicationVersion;
import com.yahoo.vespa.hosted.controller.api.integration.deployment.RevisionId;
import java.util.Objects;
@@ -23,7 +22,7 @@ import static java.util.Objects.requireNonNull;
*/
public final class Change {
- private static final Change empty = new Change(Optional.empty(), Optional.empty(), false);
+ private static final Change empty = new Change(Optional.empty(), Optional.empty(), false, false);
/** The platform version we are upgrading to, or empty if none */
private final Optional<Version> platform;
@@ -32,23 +31,27 @@ public final class Change {
private final Optional<RevisionId> revision;
/** Whether this change is a pin to its contained Vespa version, or to the application's current. */
- private final boolean pinned;
+ private final boolean platformPinned;
- private Change(Optional<Version> platform, Optional<RevisionId> revision, boolean pinned) {
+ /** Whether this change is a pin to its contained application revision, or to the application's current. */
+ private final boolean revisionPinned;
+
+ private Change(Optional<Version> platform, Optional<RevisionId> revision, boolean platformPinned, boolean revisionPinned) {
this.platform = requireNonNull(platform, "platform cannot be null");
this.revision = requireNonNull(revision, "revision cannot be null");
if (revision.isPresent() && ( ! revision.get().isProduction())) {
throw new IllegalArgumentException("Application version to deploy must be a known version");
}
- this.pinned = pinned;
+ this.platformPinned = platformPinned;
+ this.revisionPinned = revisionPinned;
}
public Change withoutPlatform() {
- return new Change(Optional.empty(), revision, pinned);
+ return new Change(Optional.empty(), revision, platformPinned, revisionPinned);
}
public Change withoutApplication() {
- return new Change(platform, Optional.empty(), pinned);
+ return new Change(platform, Optional.empty(), platformPinned, revisionPinned);
}
/** Returns whether a change should currently be deployed */
@@ -58,7 +61,7 @@ public final class Change {
/** Returns whether this is the empty change. */
public boolean isEmpty() {
- return ! hasTargets() && ! pinned;
+ return ! hasTargets() && ! platformPinned && ! revisionPinned;
}
/** Returns the platform version carried by this. */
@@ -67,42 +70,55 @@ public final class Change {
/** Returns the application version carried by this. */
public Optional<RevisionId> revision() { return revision; }
- public boolean isPinned() { return pinned; }
+ public boolean isPlatformPinned() { return platformPinned; }
+
+ public boolean isRevisionPinned() { return revisionPinned; }
/** Returns an instance representing no change */
public static Change empty() { return empty; }
/** Returns a version of this change which replaces or adds this platform change */
public Change with(Version platformVersion) {
- if (pinned)
+ if (platformPinned)
throw new IllegalArgumentException("Not allowed to set a platform version when pinned.");
- return new Change(Optional.of(platformVersion), revision, pinned);
+ return new Change(Optional.of(platformVersion), revision, platformPinned, revisionPinned);
}
/** Returns a version of this change which replaces or adds this revision change */
public Change with(RevisionId revision) {
- return new Change(platform, Optional.of(revision), pinned);
+ if (revisionPinned)
+ throw new IllegalArgumentException("Not allowed to set a revision when pinned.");
+
+ return new Change(platform, Optional.of(revision), platformPinned, revisionPinned);
+ }
+
+ /** Returns a change with the versions of this, and with the platform version pinned. */
+ public Change withPlatformPin() {
+ return new Change(platform, revision, true, revisionPinned);
+ }
+
+ /** Returns a change with the versions of this, and with the platform version unpinned. */
+ public Change withoutPlatformPin() {
+ return new Change(platform, revision, false, revisionPinned);
}
/** Returns a change with the versions of this, and with the platform version pinned. */
- public Change withPin() {
- return new Change(platform, revision, true);
+ public Change withRevisionPin() {
+ return new Change(platform, revision, platformPinned, true);
}
/** Returns a change with the versions of this, and with the platform version unpinned. */
- public Change withoutPin() {
- return new Change(platform, revision, false);
+ public Change withoutRevisionPin() {
+ return new Change(platform, revision, platformPinned, false);
}
/** Returns the change obtained when overwriting elements of the given change with any present in this */
public Change onTopOf(Change other) {
- if (platform.isPresent())
- other = other.with(platform.get());
- if (revision.isPresent())
- other = other.with(revision.get());
- if (pinned)
- other = other.withPin();
+ if (platform.isPresent()) other = other.with(platform.get());
+ if (revision.isPresent()) other = other.with(revision.get());
+ if (platformPinned) other = other.withPlatformPin();
+ if (revisionPinned) other = other.withRevisionPin();
return other;
}
@@ -111,34 +127,38 @@ public final class Change {
if (this == o) return true;
if (!(o instanceof Change)) return false;
Change change = (Change) o;
- return pinned == change.pinned &&
+ return platformPinned == change.platformPinned &&
+ revisionPinned == change.revisionPinned &&
Objects.equals(platform, change.platform) &&
Objects.equals(revision, change.revision);
}
@Override
public int hashCode() {
- return Objects.hash(platform, revision, pinned);
+ return Objects.hash(platform, revision, platformPinned, revisionPinned);
}
@Override
public String toString() {
StringJoiner changes = new StringJoiner(" and ");
- if (pinned)
+ if (platformPinned)
changes.add("pin to " + platform.map(Version::toString).orElse("current platform"));
else
platform.ifPresent(version -> changes.add("upgrade to " + version));
- revision.ifPresent(revision -> changes.add("revision change to " + revision));
+ if (revisionPinned)
+ changes.add("pin to " + revision.map(RevisionId::toString).orElse("current revision"));
+ else
+ revision.ifPresent(revision -> changes.add("revision change to " + revision));
changes.setEmptyValue("no change");
return changes.toString();
}
public static Change of(RevisionId revision) {
- return new Change(Optional.empty(), Optional.of(revision), false);
+ return new Change(Optional.empty(), Optional.of(revision), false, false);
}
public static Change of(Version platformChange) {
- return new Change(Optional.of(platformChange), Optional.empty(), false);
+ return new Change(Optional.of(platformChange), Optional.empty(), false, false);
}
/** Returns whether this change carries a revision downgrade relative to the given revision. */
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/InstanceList.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/InstanceList.java
index c1bf083b26c..b94779994e4 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/InstanceList.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/InstanceList.java
@@ -125,7 +125,7 @@ public class InstanceList extends AbstractFilteringList<ApplicationId, InstanceL
/** Returns the subset of instances which are not pinned to a certain Vespa version. */
public InstanceList unpinned() {
- return matching(id -> ! instance(id).change().isPinned());
+ return matching(id -> ! instance(id).change().isPlatformPinned());
}
/** Returns the subset of instances which are currently failing a job. */
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java
index 00da34fe2e4..0f1bbfeb25e 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java
@@ -229,7 +229,7 @@ public class DeploymentStatus {
.anyMatch(deployment -> ! compatibleWithCompileVersion.test(deployment.version()))) {
for (Version platform : targetsForPolicy(versionStatus, systemVersion, application.deploymentSpec().requireInstance(instance).upgradePolicy()))
if (compatibleWithCompileVersion.test(platform))
- return change.withoutPin().with(platform);
+ return change.withoutPlatformPin().with(platform);
}
return change;
}
@@ -265,7 +265,7 @@ public class DeploymentStatus {
for (InstanceName instance : application.deploymentSpec().instanceNames()) {
Change outstanding = outstandingChange(instance);
if (outstanding.hasTargets())
- outstandingChanges.put(instance, outstanding.onTopOf(application.require(instance).change()));
+ outstandingChanges.put(instance, outstanding.onTopOf(application.require(instance).change().withoutRevisionPin()));
}
var testJobs = jobsToRun(outstandingChanges, true).entrySet().stream()
.filter(entry -> ! entry.getKey().type().isProduction());
@@ -596,7 +596,8 @@ public class DeploymentStatus {
/** Changes to deploy with the given job, possibly split in two steps. */
private List<Change> changes(JobId job, StepStatus step, Change change) {
- if (change.platform().isEmpty() || change.revision().isEmpty() || change.isPinned())
+ if ( change.platform().isEmpty() || change.revision().isEmpty()
+ || change.isPlatformPinned() || change.isRevisionPinned())
return List.of(change);
if ( step.completedAt(change.withoutApplication(), Optional.of(job)).isPresent()
@@ -1090,14 +1091,14 @@ public class DeploymentStatus {
/** Complete if deployment is on pinned version, and last successful deployment, or if given versions is strictly a downgrade, and this isn't forced by a pin. */
@Override
Optional<Instant> completedAt(Change change, Optional<JobId> dependent) {
- if ( change.isPinned()
+ if ( change.isPlatformPinned()
&& change.platform().isPresent()
&& ! existingDeployment.map(Deployment::version).equals(change.platform()))
return Optional.empty();
if ( change.revision().isPresent()
- && ! existingDeployment.map(Deployment::revision).equals(change.revision())
- && dependent.equals(job())) // Job should (re-)run in this case, but other dependents need not wait.
+ && change.isRevisionPinned()
+ && ! existingDeployment.map(Deployment::revision).equals(change.revision()))
return Optional.empty();
Change fullChange = status.application().require(job.id().application().instance()).change();
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java
index 00a0e22f87d..4e699f2c28f 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java
@@ -41,7 +41,6 @@ import java.util.logging.Logger;
import java.util.stream.Collectors;
import static java.util.Comparator.comparing;
-import static java.util.Comparator.comparingDouble;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.toMap;
@@ -331,15 +330,14 @@ public class DeploymentTrigger {
/** Cancels the indicated part of the given application's change. */
public void cancelChange(ApplicationId instanceId, ChangesToCancel cancellation) {
applications().lockApplicationOrThrow(TenantAndApplicationId.from(instanceId), application -> {
- Change change;
- switch (cancellation) {
- case ALL: change = Change.empty(); break;
- case VERSIONS: change = Change.empty().withPin(); break;
- case PLATFORM: change = application.get().require(instanceId.instance()).change().withoutPlatform(); break;
- case APPLICATION: change = application.get().require(instanceId.instance()).change().withoutApplication(); break;
- case PIN: change = application.get().require(instanceId.instance()).change().withoutPin(); break;
- default: throw new IllegalArgumentException("Unknown cancellation choice '" + cancellation + "'!");
- }
+ Change change = switch (cancellation) {
+ case ALL -> Change.empty();
+ case PLATFORM -> application.get().require(instanceId.instance()).change().withoutPlatform();
+ case APPLICATION -> application.get().require(instanceId.instance()).change().withoutApplication();
+ case PIN -> application.get().require(instanceId.instance()).change().withoutPlatformPin();
+ case PLATFORM_PIN -> application.get().require(instanceId.instance()).change().withoutPlatformPin();
+ case APPLICATION_PIN -> application.get().require(instanceId.instance()).change().withoutRevisionPin();
+ };
applications().store(application.with(instanceId.instance(),
instance -> withRemainingChange(instance,
change,
@@ -348,7 +346,7 @@ public class DeploymentTrigger {
});
}
- public enum ChangesToCancel { ALL, PLATFORM, APPLICATION, VERSIONS, PIN }
+ public enum ChangesToCancel { ALL, PLATFORM, APPLICATION, PIN, PLATFORM_PIN, APPLICATION_PIN }
// ---------- Conveniences ----------
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java
index 7b1a1e879d6..52ddcfd5171 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java
@@ -437,7 +437,7 @@ public class InternalStepRunner implements StepRunner {
Version targetPlatform = controller.jobController().run(id).versions().targetPlatform();
Version systemVersion = controller.readSystemVersion();
boolean incompatible = controller.applications().versionCompatibility(id.application()).refuse(targetPlatform, systemVersion);
- return incompatible || application(id.application()).change().isPinned() ? targetPlatform : systemVersion;
+ return incompatible || application(id.application()).change().isPlatformPinned() ? targetPlatform : systemVersion;
}
private Optional<RunStatus> installTester(RunId id, DualLogger logger) {
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java
index 10e4052f067..318a6ffe820 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java
@@ -111,6 +111,7 @@ import static java.util.logging.Level.WARNING;
public class JobController {
public static final Duration maxHistoryAge = Duration.ofDays(60);
+ public static final Duration obsoletePackageExpiry = Duration.ofDays(7);
private static final Logger log = Logger.getLogger(JobController.class.getName());
@@ -165,8 +166,8 @@ public class JobController {
return Optional.empty();
return active(id).isPresent()
- ? Optional.of(logs.readActive(id.application(), id.type(), after))
- : logs.readFinished(id, after);
+ ? Optional.of(logs.readActive(id.application(), id.type(), after))
+ : logs.readFinished(id, after);
}
}
@@ -284,10 +285,10 @@ public class JobController {
private Optional<InputStream> getVespaLogsFromLogserver(Run run, long fromMillis, boolean tester) {
return deploymentCompletedAt(run, tester).map(at ->
- controller.serviceRegistry().configServer().getLogs(new DeploymentId(tester ? run.id().tester().id() : run.id().application(),
- run.id().type().zone()),
- Map.of("from", Long.toString(Math.max(fromMillis, at.toEpochMilli())),
- "to", Long.toString(run.end().orElse(controller.clock().instant()).toEpochMilli()))));
+ controller.serviceRegistry().configServer().getLogs(new DeploymentId(tester ? run.id().tester().id() : run.id().application(),
+ run.id().type().zone()),
+ Map.of("from", Long.toString(Math.max(fromMillis, at.toEpochMilli())),
+ "to", Long.toString(run.end().orElse(controller.clock().instant()).toEpochMilli()))));
}
/** Fetches any new test log entries, and records the id of the last of these, for continuation. */
@@ -509,14 +510,14 @@ public class JobController {
long successes = runs.values().stream().filter(Run::hasSucceeded).count();
var oldEntries = runs.entrySet().iterator();
for (var old = oldEntries.next();
- old.getKey().number() <= last - historyLength
+ old.getKey().number() <= last - historyLength
|| old.getValue().start().isBefore(controller.clock().instant().minus(maxHistoryAge));
old = oldEntries.next()) {
// Make sure we keep the last success and the first failing
if ( successes == 1
- && old.getValue().hasSucceeded()
- && ! old.getValue().start().isBefore(controller.clock().instant().minus(maxHistoryAge))) {
+ && old.getValue().hasSucceeded()
+ && ! old.getValue().start().isBefore(controller.clock().instant().minus(maxHistoryAge))) {
oldEntries.next();
continue;
}
@@ -624,7 +625,7 @@ public class JobController {
});
}
- private LockedApplication withPrunedPackages(LockedApplication application, RevisionId latest){
+ private LockedApplication withPrunedPackages(LockedApplication application, RevisionId latest) {
TenantAndApplicationId id = application.get().id();
Application wrapped = application.get();
RevisionId oldestDeployed = application.get().oldestDeployedRevision()
@@ -632,11 +633,28 @@ public class JobController {
.flatMap(instance -> instance.change().revision().stream())
.min(naturalOrder()))
.orElse(latest);
- controller.applications().applicationStore().prune(id.tenant(), id.application(), oldestDeployed);
+ RevisionId oldestToKeep = null;
+ Instant now = controller.clock().instant();
+ for (ApplicationVersion version : application.get().revisions().withPackage()) {
+ if (version.id().compareTo(oldestDeployed) < 0) {
+ if (version.obsoleteAt().isEmpty()) {
+ application = application.withRevisions(revisions -> revisions.with(version.obsoleteAt(now)));
+ if (oldestToKeep == null)
+ oldestToKeep = version.id();
+ }
+ else {
+ if (oldestToKeep == null && !version.obsoleteAt().get().isBefore(now.minus(obsoletePackageExpiry)))
+ oldestToKeep = version.id();
+ }
+ }
+ }
- for (ApplicationVersion version : application.get().revisions().withPackage())
- if (version.id().compareTo(oldestDeployed) < 0)
- application = application.withRevisions(revisions -> revisions.with(version.withoutPackage()));
+ if (oldestToKeep != null) {
+ controller.applications().applicationStore().prune(id.tenant(), id.application(), oldestToKeep);
+ for (ApplicationVersion version : application.get().revisions().withPackage())
+ if (version.id().compareTo(oldestToKeep) < 0)
+ application = application.withRevisions(revisions -> revisions.with(version.withoutPackage()));
+ }
return application;
}
@@ -703,8 +721,8 @@ public class JobController {
VersionStatus versionStatus = controller.readVersionStatus();
if ( ! controller.system().isCd()
- && platform.isPresent()
- && versionStatus.deployableVersions().stream().map(VespaVersion::versionNumber).noneMatch(platform.get()::equals))
+ && platform.isPresent()
+ && versionStatus.deployableVersions().stream().map(VespaVersion::versionNumber).noneMatch(platform.get()::equals))
throw new IllegalArgumentException("platform version " + platform.get() + " is not present in this system");
controller.applications().lockApplicationOrThrow(TenantAndApplicationId.from(id), application -> {
@@ -731,8 +749,8 @@ public class JobController {
controller.applications().lockApplicationOrThrow(TenantAndApplicationId.from(id), application -> {
Version targetPlatform = platform.orElseGet(() -> findTargetPlatform(applicationPackage, deploymentId, application.get().get(id.instance()), versionStatus));
if ( ! allowOutdatedPlatform
- && ! controller.readVersionStatus().isOnCurrentMajor(targetPlatform)
- && runs(id, type).values().stream().noneMatch(run -> run.versions().targetPlatform().getMajor() == targetPlatform.getMajor()))
+ && ! controller.readVersionStatus().isOnCurrentMajor(targetPlatform)
+ && runs(id, type).values().stream().noneMatch(run -> run.versions().targetPlatform().getMajor() == targetPlatform.getMajor()))
throw new IllegalArgumentException("platform version " + targetPlatform + " is not on a current major version in this system");
controller.applications().applicationStore().putDev(deploymentId, version.id(), applicationPackage.zippedContent(), diff);
@@ -872,7 +890,7 @@ public class JobController {
/** Locks all runs and modifies the list of historic runs for the given application and job type. */
private void locked(ApplicationId id, JobType type, Consumer<SortedMap<RunId, Run>> modifications) {
- try (Mutex __ = curator.lock(id, type)) {
+ try (Mutex __ = curator.lock(id, type)) {
SortedMap<RunId, Run> runs = new TreeMap<>(curator.readHistoricRuns(id, type));
modifications.accept(runs);
curator.writeHistoricRuns(id, type, runs.values());
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/RevisionHistory.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/RevisionHistory.java
index bbab9487ea2..272417ba0ac 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/RevisionHistory.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/RevisionHistory.java
@@ -93,7 +93,7 @@ public class RevisionHistory {
// Fallback for when an application version isn't known for the given key.
private static ApplicationVersion revisionOf(RevisionId id) {
- return new ApplicationVersion(id, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), false, false, Optional.empty(), 0);
+ return new ApplicationVersion(id, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), false, false, Optional.empty(), 0);
}
/** Returns the production {@link ApplicationVersion} with this revision ID. */
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/Versions.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/Versions.java
index e7371561636..f752e396c09 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/Versions.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/Versions.java
@@ -126,7 +126,7 @@ public class Versions {
private static Version targetPlatform(Application application, Change change, Optional<Version> existing,
Supplier<Version> defaultVersion) {
- if (change.isPinned() && change.platform().isPresent())
+ if (change.isPlatformPinned() && change.platform().isPresent())
return change.platform().get();
return max(change.platform(), existing)
@@ -135,6 +135,9 @@ public class Versions {
private static RevisionId targetRevision(Application application, Change change,
Optional<RevisionId> existing) {
+ if (change.isRevisionPinned() && change.revision().isPresent())
+ return change.revision().get();
+
return change.revision()
.or(() -> existing)
.orElseGet(() -> defaultRevision(application));
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java
index ee12c9957b1..e5006ab9785 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java
@@ -82,7 +82,8 @@ public class ApplicationSerializer {
private static final String versionsField = "versions";
private static final String prodVersionsField = "prodVersions";
private static final String devVersionsField = "devVersions";
- private static final String pinnedField = "pinned";
+ private static final String platformPinnedField = "pinned";
+ private static final String revisionPinnedField = "revisionPinned";
private static final String deploymentIssueField = "deploymentIssueId";
private static final String ownershipIssueIdField = "ownershipIssueId";
private static final String ownerField = "confirmedOwner";
@@ -118,6 +119,7 @@ public class ApplicationSerializer {
private static final String riskField = "risk";
private static final String authorEmailField = "authorEmailField";
private static final String deployedDirectlyField = "deployedDirectly";
+ private static final String obsoleteAtField = "obsoleteAt";
private static final String hasPackageField = "hasPackage";
private static final String shouldSkipField = "shouldSkip";
private static final String compileVersionField = "compileVersion";
@@ -265,6 +267,7 @@ public class ApplicationSerializer {
applicationVersion.sourceUrl().ifPresent(url -> object.setString(sourceUrlField, url));
applicationVersion.commit().ifPresent(commit -> object.setString(commitField, commit));
object.setBool(deployedDirectlyField, applicationVersion.isDeployedDirectly());
+ applicationVersion.obsoleteAt().ifPresent(at -> object.setLong(obsoleteAtField, at.toEpochMilli()));
object.setBool(hasPackageField, applicationVersion.hasPackage());
object.setBool(shouldSkipField, applicationVersion.shouldSkip());
applicationVersion.description().ifPresent(description -> object.setString(descriptionField, description));
@@ -295,8 +298,10 @@ public class ApplicationSerializer {
object.setString(versionField, deploying.platform().get().toString());
if (deploying.revision().isPresent())
toSlime(deploying.revision().get(), object);
- if (deploying.isPinned())
- object.setBool(pinnedField, true);
+ if (deploying.isPlatformPinned())
+ object.setBool(platformPinnedField, true);
+ if (deploying.isRevisionPinned())
+ object.setBool(revisionPinnedField, true);
}
private void toSlime(RotationStatus status, Cursor array) {
@@ -487,6 +492,7 @@ public class ApplicationSerializer {
Optional<Instant> buildTime = SlimeUtils.optionalInstant(object.field(buildTimeField));
Optional<String> sourceUrl = SlimeUtils.optionalString(object.field(sourceUrlField));
Optional<String> commit = SlimeUtils.optionalString(object.field(commitField));
+ Optional<Instant> obsoleteAt = SlimeUtils.optionalInstant(object.field(obsoleteAtField));
boolean hasPackage = object.field(hasPackageField).asBool();
boolean shouldSkip = object.field(shouldSkipField).asBool();
Optional<String> description = SlimeUtils.optionalString(object.field(descriptionField));
@@ -494,7 +500,7 @@ public class ApplicationSerializer {
Optional<String> bundleHash = SlimeUtils.optionalString(object.field(bundleHashField));
return new ApplicationVersion(id, sourceRevision, authorEmail, compileVersion, allowedMajor, buildTime,
- sourceUrl, commit, bundleHash, hasPackage, shouldSkip, description, risk);
+ sourceUrl, commit, bundleHash, obsoleteAt, hasPackage, shouldSkip, description, risk);
}
private Optional<SourceRevision> sourceRevisionFromSlime(Inspector object) {
@@ -520,8 +526,10 @@ public class ApplicationSerializer {
change = Change.of(Version.fromString(versionFieldValue.asString()));
if (object.field(applicationBuildNumberField).valid())
change = change.with(revisionFromSlime(object, null));
- if (object.field(pinnedField).asBool())
- change = change.withPin();
+ if (object.field(platformPinnedField).asBool())
+ change = change.withPlatformPin();
+ if (object.field(revisionPinnedField).asBool())
+ change = change.withRevisionPin();
return change;
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
index 81988753621..ded27ee1060 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
@@ -260,11 +260,9 @@ public class ApplicationApiHandler extends AuditLoggingRequestHandler {
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/package")) return applicationPackage(path.get("tenant"), path.get("application"), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/diff/{number}")) return applicationPackageDiff(path.get("tenant"), path.get("application"), path.get("number"));
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/deploying")) return deploying(path.get("tenant"), path.get("application"), "default", request);
- if (path.matches("/application/v4/tenant/{tenant}/application/{application}/deploying/pin")) return deploying(path.get("tenant"), path.get("application"), "default", request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance")) return applications(path.get("tenant"), Optional.of(path.get("application")), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}")) return instance(path.get("tenant"), path.get("application"), path.get("instance"), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/deploying")) return deploying(path.get("tenant"), path.get("application"), path.get("instance"), request);
- if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/deploying/pin")) return deploying(path.get("tenant"), path.get("application"), path.get("instance"), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/job")) return JobControllerApiHandlerHelper.jobTypeResponse(controller, appIdFromPath(path), request.getUri());
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/job/{jobtype}")) return JobControllerApiHandlerHelper.runResponse(controller.applications().requireApplication(TenantAndApplicationId.from(path.get("tenant"), path.get("application"))), controller.jobController().runs(appIdFromPath(path), jobTypeFromPath(path)).descendingMap(), Optional.ofNullable(request.getProperty("limit")), request.getUri()); // (((\(✘෴✘)/)))
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/job/{jobtype}/package")) return devApplicationPackage(appIdFromPath(path), jobTypeFromPath(path));
@@ -327,14 +325,18 @@ public class ApplicationApiHandler extends AuditLoggingRequestHandler {
if (path.matches("/application/v4/tenant/{tenant}/application/{application}")) return createApplication(path.get("tenant"), path.get("application"), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/deploying/platform")) return deployPlatform(path.get("tenant"), path.get("application"), "default", false, request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/deploying/pin")) return deployPlatform(path.get("tenant"), path.get("application"), "default", true, request);
- if (path.matches("/application/v4/tenant/{tenant}/application/{application}/deploying/application")) return deployApplication(path.get("tenant"), path.get("application"), "default", request);
+ if (path.matches("/application/v4/tenant/{tenant}/application/{application}/deploying/platform-pin")) return deployPlatform(path.get("tenant"), path.get("application"), "default", true, request);
+ if (path.matches("/application/v4/tenant/{tenant}/application/{application}/deploying/application-pin")) return deployApplication(path.get("tenant"), path.get("application"), "default", true, request);
+ if (path.matches("/application/v4/tenant/{tenant}/application/{application}/deploying/application")) return deployApplication(path.get("tenant"), path.get("application"), "default", false, request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/key")) return addDeployKey(path.get("tenant"), path.get("application"), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/submit")) return submit(path.get("tenant"), path.get("application"), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}")) return createInstance(path.get("tenant"), path.get("application"), path.get("instance"), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/deploy/{jobtype}")) return jobDeploy(appIdFromPath(path), jobTypeFromPath(path), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/deploying/platform")) return deployPlatform(path.get("tenant"), path.get("application"), path.get("instance"), false, request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/deploying/pin")) return deployPlatform(path.get("tenant"), path.get("application"), path.get("instance"), true, request);
- if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/deploying/application")) return deployApplication(path.get("tenant"), path.get("application"), path.get("instance"), request);
+ if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/deploying/platform-pin")) return deployPlatform(path.get("tenant"), path.get("application"), path.get("instance"), true, request);
+ if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/deploying/application-pin")) return deployApplication(path.get("tenant"), path.get("application"), path.get("instance"), true, request);
+ if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/deploying/application")) return deployApplication(path.get("tenant"), path.get("application"), path.get("instance"), false, request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/submit")) return submit(path.get("tenant"), path.get("application"), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/job/{jobtype}")) return trigger(appIdFromPath(path), jobTypeFromPath(path), request);
if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/job/{jobtype}/pause")) return pause(appIdFromPath(path), jobTypeFromPath(path));
@@ -2059,7 +2061,9 @@ public class ApplicationApiHandler extends AuditLoggingRequestHandler {
if ( ! instance.change().isEmpty()) {
instance.change().platform().ifPresent(version -> root.setString("platform", version.toString()));
instance.change().revision().ifPresent(revision -> root.setString("application", revision.toString()));
- root.setBool("pinned", instance.change().isPinned());
+ root.setBool("pinned", instance.change().isPlatformPinned());
+ root.setBool("platform-pinned", instance.change().isPlatformPinned());
+ root.setBool("application-pinned", instance.change().isRevisionPinned());
}
return new SlimeJsonResponse(slime);
}
@@ -2172,7 +2176,7 @@ public class ApplicationApiHandler extends AuditLoggingRequestHandler {
.collect(joining(", ")));
Change change = Change.of(version);
if (pin)
- change = change.withPin();
+ change = change.withPlatformPin();
controller.applications().deploymentTrigger().forceChange(id, change, isOperator(request));
response.append("Triggered ").append(change).append(" for ").append(id);
@@ -2181,7 +2185,7 @@ public class ApplicationApiHandler extends AuditLoggingRequestHandler {
}
/** Trigger deployment to the last known application package for the given application. */
- private HttpResponse deployApplication(String tenantName, String applicationName, String instanceName, HttpRequest request) {
+ private HttpResponse deployApplication(String tenantName, String applicationName, String instanceName, boolean pin, HttpRequest request) {
ApplicationId id = ApplicationId.from(tenantName, applicationName, instanceName);
Inspector buildField = toSlime(request.getData()).get().field("build");
long build = buildField.valid() ? buildField.asLong() : -1;
@@ -2191,6 +2195,8 @@ public class ApplicationApiHandler extends AuditLoggingRequestHandler {
RevisionId revision = build == -1 ? application.get().revisions().last().get().id()
: getRevision(application.get(), build);
Change change = Change.of(revision);
+ if (pin)
+ change = change.withRevisionPin();
controller.applications().deploymentTrigger().forceChange(id, change, isOperator(request));
response.append("Triggered ").append(change).append(" for ").append(id);
});
@@ -2231,7 +2237,7 @@ public class ApplicationApiHandler extends AuditLoggingRequestHandler {
return;
}
- ChangesToCancel cancel = ChangesToCancel.valueOf(choice.toUpperCase());
+ ChangesToCancel cancel = ChangesToCancel.valueOf(choice.replaceAll("-", "_").toUpperCase());
controller.applications().deploymentTrigger().cancelChange(id, cancel);
response.append("Changed deployment from '").append(change).append("' to '").append(controller.applications().requireInstance(id).change()).append("' for ").append(id);
});
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelper.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelper.java
index 804ae7b7805..9ff8c7df18b 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelper.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/JobControllerApiHandlerHelper.java
@@ -312,7 +312,9 @@ class JobControllerApiHandlerHelper {
if ( ! change.isEmpty()) {
change.platform().ifPresent(version -> deployingObject.setString("platform", version.toFullString()));
change.revision().ifPresent(revision -> toSlime(deployingObject.setObject("application"), application.revisions().get(revision)));
- if (change.isPinned()) deployingObject.setBool("pinned", true);
+ if (change.isPlatformPinned()) deployingObject.setBool("pinned", true);
+ if (change.isPlatformPinned()) deployingObject.setBool("platformPinned", true);
+ if (change.isRevisionPinned()) deployingObject.setBool("revisionPinned", true);
}
Cursor latestVersionsObject = stepObject.setObject("latestVersions");
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/deployment/DeploymentApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/deployment/DeploymentApiHandler.java
index 069ee58e9c5..6e5635e8c8c 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/deployment/DeploymentApiHandler.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/deployment/DeploymentApiHandler.java
@@ -171,7 +171,9 @@ public class DeploymentApiHandler extends ThreadedHttpRequestHandler {
instanceObject.setString("application", instance.application().value());
instanceObject.setString("instance", instance.instance().value());
instanceObject.setBool("upgrading", status.application().require(instance.instance()).change().platform().equals(Optional.of(statistics.version())));
- instanceObject.setBool("pinned", status.application().require(instance.instance()).change().isPinned());
+ instanceObject.setBool("pinned", status.application().require(instance.instance()).change().isPlatformPinned());
+ instanceObject.setBool("platformPinned", status.application().require(instance.instance()).change().isPlatformPinned());
+ instanceObject.setBool("revisionPinned", status.application().require(instance.instance()).change().isRevisionPinned());
DeploymentStatus.StepStatus stepStatus = status.instanceSteps().get(instance.instance());
if (stepStatus != null) { // Instance may not have any steps, i.e. an empty deployment spec has been submitted
Readiness platformReadiness = stepStatus.blockedUntil(Change.of(statistics.version()));
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java
index a9a6fe602b6..04c8c46e1ef 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java
@@ -41,6 +41,7 @@ import com.yahoo.vespa.hosted.controller.application.pkg.ApplicationPackage;
import com.yahoo.vespa.hosted.controller.deployment.ApplicationPackageBuilder;
import com.yahoo.vespa.hosted.controller.deployment.DeploymentContext;
import com.yahoo.vespa.hosted.controller.deployment.DeploymentTester;
+import com.yahoo.vespa.hosted.controller.deployment.JobController;
import com.yahoo.vespa.hosted.controller.deployment.Submission;
import com.yahoo.vespa.hosted.controller.integration.ZoneApiMock;
import com.yahoo.vespa.hosted.controller.notification.Notification;
@@ -106,9 +107,11 @@ public class ControllerTest {
Version version1 = tester.configServer().initialVersion();
var context = tester.newDeploymentContext();
context.submit(applicationPackage);
- assertEquals(ApplicationVersion.from(RevisionId.forProduction(1), DeploymentContext.defaultSourceRevision, "a@b", new Version("6.1"), Instant.ofEpochSecond(1)),
- context.application().revisions().get(context.instance().change().revision().get()),
- "Application version is known from completion of initial job");
+ RevisionId id = RevisionId.forProduction(1);
+ Version compileVersion = new Version("6.1");
+ assertEquals(new ApplicationVersion(id, Optional.of(DeploymentContext.defaultSourceRevision), Optional.of("a@b"), Optional.of(compileVersion), Optional.empty(), Optional.of(Instant.ofEpochSecond(1)), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), true, false, Optional.empty(), 0),
+ context.application().revisions().get(context.instance().change().revision().get()),
+ "Application version is known from completion of initial job");
context.runJob(systemTest);
context.runJob(stagingTest);
@@ -220,6 +223,59 @@ public class ControllerTest {
}
@Test
+ void testPackagePruning() {
+ DeploymentContext app = tester.newDeploymentContext().submit().deploy();
+ RevisionId revision1 = app.lastSubmission().get();
+ assertTrue(tester.controllerTester().serviceRegistry().applicationStore()
+ .hasBuild(app.instanceId().tenant(), app.instanceId().application(), revision1.number()));
+
+ app.submit().deploy();
+ RevisionId revision2 = app.lastSubmission().get();
+ assertTrue(tester.controllerTester().serviceRegistry().applicationStore()
+ .hasBuild(app.instanceId().tenant(), app.instanceId().application(), revision1.number()));
+ assertTrue(tester.controllerTester().serviceRegistry().applicationStore()
+ .hasBuild(app.instanceId().tenant(), app.instanceId().application(), revision2.number()));
+
+ // Revision 1 is marked as obsolete now
+ app.submit().deploy();
+ RevisionId revision3 = app.lastSubmission().get();
+ assertTrue(tester.controllerTester().serviceRegistry().applicationStore()
+ .hasBuild(app.instanceId().tenant(), app.instanceId().application(), revision1.number()));
+ assertTrue(tester.controllerTester().serviceRegistry().applicationStore()
+ .hasBuild(app.instanceId().tenant(), app.instanceId().application(), revision2.number()));
+ assertTrue(tester.controllerTester().serviceRegistry().applicationStore()
+ .hasBuild(app.instanceId().tenant(), app.instanceId().application(), revision3.number()));
+
+ // Time advances, and revision 2 is marked as obsolete now
+ tester.clock().advance(JobController.obsoletePackageExpiry);
+ app.submit().deploy();
+ RevisionId revision4 = app.lastSubmission().get();
+ assertTrue(tester.controllerTester().serviceRegistry().applicationStore()
+ .hasBuild(app.instanceId().tenant(), app.instanceId().application(), revision1.number()));
+ assertTrue(tester.controllerTester().serviceRegistry().applicationStore()
+ .hasBuild(app.instanceId().tenant(), app.instanceId().application(), revision2.number()));
+ assertTrue(tester.controllerTester().serviceRegistry().applicationStore()
+ .hasBuild(app.instanceId().tenant(), app.instanceId().application(), revision3.number()));
+ assertTrue(tester.controllerTester().serviceRegistry().applicationStore()
+ .hasBuild(app.instanceId().tenant(), app.instanceId().application(), revision4.number()));
+
+ // Time advances, and revision is now old enough to be pruned
+ tester.clock().advance(Duration.ofMillis(1));
+ app.submit().deploy();
+ RevisionId revision5 = app.lastSubmission().get();
+ assertFalse(tester.controllerTester().serviceRegistry().applicationStore()
+ .hasBuild(app.instanceId().tenant(), app.instanceId().application(), revision1.number()));
+ assertTrue(tester.controllerTester().serviceRegistry().applicationStore()
+ .hasBuild(app.instanceId().tenant(), app.instanceId().application(), revision2.number()));
+ assertTrue(tester.controllerTester().serviceRegistry().applicationStore()
+ .hasBuild(app.instanceId().tenant(), app.instanceId().application(), revision3.number()));
+ assertTrue(tester.controllerTester().serviceRegistry().applicationStore()
+ .hasBuild(app.instanceId().tenant(), app.instanceId().application(), revision4.number()));
+ assertTrue(tester.controllerTester().serviceRegistry().applicationStore()
+ .hasBuild(app.instanceId().tenant(), app.instanceId().application(), revision5.number()));
+ }
+
+ @Test
void testGlobalRotationStatus() {
var context = tester.newDeploymentContext();
var zone1 = ZoneId.from("prod", "us-west-1");
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java
index afb92d84f3b..6e5c2458c92 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java
@@ -653,15 +653,21 @@ public class DeploymentTriggerTest {
assertEquals(appVersion1, latestDeployed(app.instance()));
// Downgrading application version.
- tester.deploymentTrigger().forceChange(app.instanceId(), Change.of(appVersion0));
- assertEquals(Change.of(appVersion0), app.instance().change());
+ tester.deploymentTrigger().forceChange(app.instanceId(), Change.of(appVersion0).withRevisionPin());
+ assertEquals(Change.of(appVersion0).withRevisionPin(), app.instance().change());
app.runJob(stagingTest)
- .runJob(productionUsCentral1)
- .runJob(productionUsEast3)
- .runJob(productionUsWest1);
- assertEquals(Change.empty(), app.instance().change());
+ .runJob(productionUsCentral1)
+ .runJob(productionUsEast3)
+ .runJob(productionUsWest1);
+ assertEquals(Change.empty().withRevisionPin(), app.instance().change());
assertEquals(appVersion0, app.instance().deployments().get(productionUsEast3.zone()).revision());
assertEquals(appVersion0, latestDeployed(app.instance()));
+
+ tester.outstandingChangeDeployer().run();
+ assertEquals(Change.empty().withRevisionPin(), app.instance().change());
+ tester.deploymentTrigger().cancelChange(app.instanceId(), ALL);
+ tester.outstandingChangeDeployer().run();
+ assertEquals(Change.of(appVersion1), app.instance().change());
}
@Test
@@ -1239,13 +1245,13 @@ public class DeploymentTriggerTest {
assertEquals(Change.empty(), app.instance().change());
// Application is pinned to previous version, and downgrades to that. Tests are re-run.
- tester.deploymentTrigger().forceChange(app.instanceId(), Change.of(version0).withPin());
+ tester.deploymentTrigger().forceChange(app.instanceId(), Change.of(version0).withPlatformPin());
app.runJob(stagingTest).runJob(productionUsEast3);
tester.clock().advance(Duration.ofMinutes(1));
app.failDeployment(testUsEast3);
tester.clock().advance(Duration.ofMinutes(11)); // Job is cooling down after consecutive failures.
app.runJob(testUsEast3);
- assertEquals(Change.empty().withPin(), app.instance().change());
+ assertEquals(Change.empty().withPlatformPin(), app.instance().change());
// A new upgrade is attempted, and production tests wait for redeployment.
tester.controllerTester().upgradeSystem(version2);
@@ -2234,7 +2240,7 @@ public class DeploymentTriggerTest {
.majorVersion(7)
.compileVersion(version1)
.build());
- tester.deploymentTrigger().forceChange(app.instanceId(), app.instance().change().withPin());
+ tester.deploymentTrigger().forceChange(app.instanceId(), app.instance().change().withPlatformPin());
app.deploy();
assertEquals(version1, tester.jobs().last(app.instanceId(), productionUsEast3).get().versions().targetPlatform());
assertEquals(version1, app.application().revisions().get(tester.jobs().last(app.instanceId(), productionUsEast3).get().versions().targetRevision()).compileVersion().get());
@@ -2251,7 +2257,7 @@ public class DeploymentTriggerTest {
// The new app enters a platform block window, and is pinned to the old platform;
// the new submission overrides both those settings, as the new revision should roll out regardless.
tester.atMondayMorning();
- tester.deploymentTrigger().forceChange(newApp.instanceId(), Change.empty().withPin());
+ tester.deploymentTrigger().forceChange(newApp.instanceId(), Change.empty().withPlatformPin());
newApp.submit(new ApplicationPackageBuilder().compileVersion(version2)
.systemTest()
.blockChange(false, true, "mon", "0-23", "UTC")
@@ -2280,11 +2286,11 @@ public class DeploymentTriggerTest {
tester.upgrader().run();
assertEquals(Change.of(newRevision).with(version1), newApp.instance().change());
- tester.deploymentTrigger().forceChange(newApp.instanceId(), newApp.instance().change().withPin());
+ tester.deploymentTrigger().forceChange(newApp.instanceId(), newApp.instance().change().withPlatformPin());
tester.outstandingChangeDeployer().run();
- assertEquals(Change.of(newRevision).with(version1).withPin(), newApp.instance().change());
+ assertEquals(Change.of(newRevision).with(version1).withPlatformPin(), newApp.instance().change());
tester.upgrader().run();
- assertEquals(Change.of(newRevision).with(version1).withPin(), newApp.instance().change());
+ assertEquals(Change.of(newRevision).with(version1).withPlatformPin(), newApp.instance().change());
newApp.deploy();
assertEquals(version1, tester.jobs().last(newApp.instanceId(), productionUsEast3).get().versions().targetPlatform());
@@ -2381,7 +2387,7 @@ public class DeploymentTriggerTest {
.build()))
.getMessage());
- tester.deploymentTrigger().forceChange(app.instanceId(), Change.of(oldVersion).with(app.application().revisions().last().get().id()).withPin());
+ tester.deploymentTrigger().forceChange(app.instanceId(), Change.of(oldVersion).with(app.application().revisions().last().get().id()).withPlatformPin());
app.deploy();
assertEquals(oldVersion, app.deployment(ZoneId.from("prod", "us-east-3")).version());
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/UpgraderTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/UpgraderTest.java
index 11110d6edaa..96c1d7c545d 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/UpgraderTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/UpgraderTest.java
@@ -5,7 +5,6 @@ import com.yahoo.component.Version;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.zone.ZoneId;
import com.yahoo.test.ManualClock;
-import com.yahoo.vespa.hosted.controller.api.integration.deployment.JobType;
import com.yahoo.vespa.hosted.controller.api.integration.deployment.RevisionId;
import com.yahoo.vespa.hosted.controller.api.integration.deployment.RunId;
import com.yahoo.vespa.hosted.controller.application.Change;
@@ -27,7 +26,6 @@ import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
-import java.util.OptionalInt;
import java.util.Set;
import java.util.stream.Collectors;
@@ -856,10 +854,10 @@ public class UpgraderTest {
// Create an application with pinned platform version.
var context = tester.newDeploymentContext().submit().deploy();
- tester.deploymentTrigger().forceChange(context.instanceId(), Change.empty().withPin());
+ tester.deploymentTrigger().forceChange(context.instanceId(), Change.empty().withPlatformPin());
assertFalse(context.instance().change().hasTargets());
- assertTrue(context.instance().change().isPinned());
+ assertTrue(context.instance().change().isPlatformPinned());
assertEquals(3, context.instance().deployments().size());
// Application does not upgrade.
@@ -867,21 +865,21 @@ public class UpgraderTest {
tester.controllerTester().upgradeSystem(version1);
tester.upgrader().maintain();
assertFalse(context.instance().change().hasTargets());
- assertTrue(context.instance().change().isPinned());
+ assertTrue(context.instance().change().isPlatformPinned());
// New application package is deployed.
context.submit().deploy();
assertFalse(context.instance().change().hasTargets());
- assertTrue(context.instance().change().isPinned());
+ assertTrue(context.instance().change().isPlatformPinned());
// Application upgrades to new version when pin is removed.
tester.deploymentTrigger().cancelChange(context.instanceId(), PIN);
tester.upgrader().maintain();
assertTrue(context.instance().change().hasTargets());
- assertFalse(context.instance().change().isPinned());
+ assertFalse(context.instance().change().isPlatformPinned());
// Application is pinned to new version, and upgrade is therefore not cancelled, even though confidence is broken.
- tester.deploymentTrigger().forceChange(context.instanceId(), Change.empty().withPin());
+ tester.deploymentTrigger().forceChange(context.instanceId(), Change.empty().withPlatformPin());
tester.upgrader().maintain();
tester.triggerJobs();
assertEquals(version1, context.instance().change().platform().get());
@@ -890,7 +888,7 @@ public class UpgraderTest {
context.runJob(systemTest).runJob(stagingTest).runJob(productionUsCentral1)
.timeOutUpgrade(productionUsWest1);
tester.deploymentTrigger().cancelChange(context.instanceId(), ALL);
- tester.deploymentTrigger().forceChange(context.instanceId(), Change.of(version0).withPin());
+ tester.deploymentTrigger().forceChange(context.instanceId(), Change.of(version0).withPlatformPin());
assertEquals(version0, context.instance().change().platform().get());
// Application downgrades to pinned version.
@@ -913,7 +911,7 @@ public class UpgraderTest {
// Keep app 1 on current version
tester.controller().applications().lockApplicationIfPresent(app1.application().id(), app ->
tester.controller().applications().store(app.with(app1.instance().name(),
- instance -> instance.withChange(instance.change().withPin()))));
+ instance -> instance.withChange(instance.change().withPlatformPin()))));
// New version is released
Version version1 = Version.fromString("6.2");
@@ -935,7 +933,7 @@ public class UpgraderTest {
// App 1 is unpinned and upgrades to latest 6
tester.controller().applications().lockApplicationIfPresent(app1.application().id(), app ->
tester.controller().applications().store(app.with(app1.instance().name(),
- instance -> instance.withChange(instance.change().withoutPin()))));
+ instance -> instance.withChange(instance.change().withoutPlatformPin()))));
tester.upgrader().maintain();
assertEquals(version1,
app1.instance().change().platform().orElseThrow(),
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java
index 589fc25700f..b71d3cf838b 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java
@@ -101,16 +101,17 @@ public class ApplicationSerializerTest {
Optional.empty(),
Optional.of("best commit"),
Optional.of("hash1"),
+ Optional.of(Instant.ofEpochMilli(777)),
true,
false,
Optional.of("~(˘▾˘)~"),
3);
assertEquals("https://github/org/repo/tree/commit1", applicationVersion1.sourceUrl().get());
- ApplicationVersion applicationVersion2 = ApplicationVersion.from(RevisionId.forDevelopment(31, new JobId(id1, DeploymentContext.productionUsEast3)),
- new SourceRevision("repo1", "branch1", "commit1"), "a@b",
- Version.fromString("6.3.1"),
- Instant.ofEpochMilli(496));
+ RevisionId id = RevisionId.forDevelopment(31, new JobId(id1, DeploymentContext.productionUsEast3));
+ SourceRevision source = new SourceRevision("repo1", "branch1", "commit1");
+ Version compileVersion = Version.fromString("6.3.1");
+ ApplicationVersion applicationVersion2 = new ApplicationVersion(id, Optional.of(source), Optional.of("a@b"), Optional.of(compileVersion), Optional.empty(), Optional.of(Instant.ofEpochMilli(496)), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), true, false, Optional.empty(), 0);
Instant activityAt = Instant.parse("2018-06-01T10:15:30.00Z");
deployments.add(new Deployment(zone1, CloudAccount.empty, applicationVersion1.id(), Version.fromString("1.2.3"), Instant.ofEpochMilli(3),
DeploymentMetrics.none, DeploymentActivity.none, QuotaUsage.none, OptionalDouble.empty()));
@@ -143,7 +144,7 @@ public class ApplicationSerializerTest {
Map.of(),
List.of(),
RotationStatus.EMPTY,
- Change.of(Version.fromString("6.7")).withPin()));
+ Change.of(Version.fromString("6.7")).withPlatformPin().withRevisionPin()));
Application original = new Application(TenantAndApplicationId.from(id1),
Instant.now().truncatedTo(ChronoUnit.MILLIS),
@@ -174,6 +175,7 @@ public class ApplicationSerializerTest {
assertEquals(original.revisions().last().get().sourceUrl(), serialized.revisions().last().get().sourceUrl());
assertEquals(original.revisions().last().get().commit(), serialized.revisions().last().get().commit());
assertEquals(original.revisions().last().get().bundleHash(), serialized.revisions().last().get().bundleHash());
+ assertEquals(original.revisions().last().get().obsoleteAt(), serialized.revisions().last().get().obsoleteAt());
assertEquals(original.revisions().last().get().hasPackage(), serialized.revisions().last().get().hasPackage());
assertEquals(original.revisions().last().get().shouldSkip(), serialized.revisions().last().get().shouldSkip());
assertEquals(original.revisions().last().get().description(), serialized.revisions().last().get().description());
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java
index 9a34989aeff..76bcbe078ff 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java
@@ -541,24 +541,22 @@ public class ApplicationApiTest extends ControllerContainerTest {
"{\"message\":\"No deployment in progress for tenant1.application1.instance1 at this time\"}");
// POST pinning to a given version to an application
- tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying/pin", POST)
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying/platform-pin", POST)
.userIdentity(USER_ID)
.data("6.1.0"),
"{\"message\":\"Triggered pin to 6.1 for tenant1.application1.instance1\"}");
assertTrue(tester.controller().auditLogger().readLog().entries().stream()
- .anyMatch(entry -> entry.resource().equals("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying/pin?")),
+ .anyMatch(entry -> entry.resource().equals("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying/platform-pin?")),
"Action is logged to audit log");
tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying", GET)
- .userIdentity(USER_ID), "{\"platform\":\"6.1\",\"pinned\":true}");
- tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying/pin", GET)
- .userIdentity(USER_ID), "{\"platform\":\"6.1\",\"pinned\":true}");
+ .userIdentity(USER_ID), "{\"platform\":\"6.1\",\"pinned\":true,\"platform-pinned\":true,\"application-pinned\":false}");
// DELETE only the pin to a given version
- tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying/pin", DELETE)
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying/platform-pin", DELETE)
.userIdentity(USER_ID),
"{\"message\":\"Changed deployment from 'pin to 6.1' to 'upgrade to 6.1' for tenant1.application1.instance1\"}");
tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying", GET)
- .userIdentity(USER_ID), "{\"platform\":\"6.1\",\"pinned\":false}");
+ .userIdentity(USER_ID), "{\"platform\":\"6.1\",\"pinned\":false,\"platform-pinned\":false,\"application-pinned\":false}");
// POST pinning again
tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying/pin", POST)
@@ -566,14 +564,14 @@ public class ApplicationApiTest extends ControllerContainerTest {
.data("6.1"),
"{\"message\":\"Triggered pin to 6.1 for tenant1.application1.instance1\"}");
tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying", GET)
- .userIdentity(USER_ID), "{\"platform\":\"6.1\",\"pinned\":true}");
+ .userIdentity(USER_ID), "{\"platform\":\"6.1\",\"pinned\":true,\"platform-pinned\":true,\"application-pinned\":false}");
// DELETE only the version, but leave the pin
tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying/platform", DELETE)
.userIdentity(USER_ID),
"{\"message\":\"Changed deployment from 'pin to 6.1' to 'pin to current platform' for tenant1.application1.instance1\"}");
tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying", GET)
- .userIdentity(USER_ID), "{\"pinned\":true}");
+ .userIdentity(USER_ID), "{\"pinned\":true,\"platform-pinned\":true,\"application-pinned\":false}");
// DELETE also the pin to a given version
tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying/pin", DELETE)
@@ -582,6 +580,32 @@ public class ApplicationApiTest extends ControllerContainerTest {
tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying", GET)
.userIdentity(USER_ID), "{}");
+ // POST pinning to a given revision to an application
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying/application-pin", POST)
+ .userIdentity(USER_ID)
+ .data(""),
+ "{\"message\":\"Triggered pin to build 1 for tenant1.application1.instance1\"}");
+ assertTrue(tester.controller().auditLogger().readLog().entries().stream()
+ .anyMatch(entry -> entry.resource().equals("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying/application-pin?")),
+ "Action is logged to audit log");
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying", GET)
+ .userIdentity(USER_ID), "{\"application\":\"build 1\",\"pinned\":false,\"platform-pinned\":false,\"application-pinned\":true}");
+
+ // DELETE only the pin to a given revision
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying/application-pin", DELETE)
+ .userIdentity(USER_ID),
+ "{\"message\":\"Changed deployment from 'pin to build 1' to 'revision change to build 1' for tenant1.application1.instance1\"}");
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying", GET)
+ .userIdentity(USER_ID), "{\"application\":\"build 1\",\"pinned\":false,\"platform-pinned\":false,\"application-pinned\":false}");
+
+ // DELETE deploying to a given revision
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying/application", DELETE)
+ .userIdentity(USER_ID),
+ "{\"message\":\"Changed deployment from 'revision change to build 1' to 'no change' for tenant1.application1.instance1\"}");
+ tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/deploying", GET)
+ .userIdentity(USER_ID), "{}");
+
+
// POST a pause to a production job
tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/job/production-us-west-1/pause", POST)
.userIdentity(USER_ID),
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment-overview.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment-overview.json
index ec6ccf3ecf2..0b7c64c72a5 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment-overview.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment-overview.json
@@ -48,6 +48,14 @@
"sourceUrl": "repository1/tree/commit1",
"commit": "commit1"
}
+ },
+ {
+ "application": {
+ "build": 1,
+ "compileVersion": "6.1.0",
+ "sourceUrl": "repository1/tree/commit1",
+ "commit": "commit1"
+ }
}
],
"blockers": [ ]
@@ -594,6 +602,14 @@
"sourceUrl": "repository1/tree/commit1",
"commit": "commit1"
}
+ },
+ {
+ "application": {
+ "build": 1,
+ "compileVersion": "6.1.0",
+ "sourceUrl": "repository1/tree/commit1",
+ "commit": "commit1"
+ }
}
],
"blockers": [ ]
@@ -709,6 +725,15 @@
"description": "my best commit yet",
"risk": 9001,
"deployable": false
+ },
+ {
+ "build": 1,
+ "compileVersion": "6.1.0",
+ "sourceUrl": "repository1/tree/commit1",
+ "commit": "commit1",
+ "description": "my best commit yet",
+ "risk": 9001,
+ "deployable": true
}
]
}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/deployment/responses/root.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/deployment/responses/root.json
index a1f386d51a7..ac43fbf2a80 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/deployment/responses/root.json
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/deployment/responses/root.json
@@ -37,6 +37,8 @@
"instance": "default",
"upgrading": false,
"pinned": false,
+ "platformPinned": false,
+ "revisionPinned": false,
"upgradePolicy": "default",
"compileVersion": "6.1.0",
"jobs": [
@@ -78,6 +80,8 @@
"instance": "i2",
"upgrading": false,
"pinned": false,
+ "platformPinned": false,
+ "revisionPinned": false,
"upgradePolicy": "default",
"compileVersion": "6.1.0",
"jobs": [
@@ -179,6 +183,8 @@
"instance": "default",
"upgrading": true,
"pinned": false,
+ "platformPinned": false,
+ "revisionPinned": false,
"upgradePolicy": "default",
"compileVersion": "6.1.0",
"jobs": [
@@ -249,6 +255,8 @@
"instance": "i1",
"upgrading": false,
"pinned": false,
+ "platformPinned": false,
+ "revisionPinned": false,
"upgradePolicy": "default",
"compileVersion": "6.1.0",
"jobs": [
@@ -309,6 +317,8 @@
"instance": "i2",
"upgrading": true,
"pinned": false,
+ "platformPinned": false,
+ "revisionPinned": false,
"upgradePolicy": "default",
"compileVersion": "6.1.0",
"jobs": [
diff --git a/document/src/main/java/com/yahoo/document/json/JsonSerializationHelper.java b/document/src/main/java/com/yahoo/document/json/JsonSerializationHelper.java
index 110564bea46..795f8e93187 100644
--- a/document/src/main/java/com/yahoo/document/json/JsonSerializationHelper.java
+++ b/document/src/main/java/com/yahoo/document/json/JsonSerializationHelper.java
@@ -5,7 +5,6 @@ import com.fasterxml.jackson.core.JsonGenerator;
import com.yahoo.document.DataType;
import com.yahoo.document.DocumentId;
import com.yahoo.document.Field;
-import com.yahoo.document.PositionDataType;
import com.yahoo.document.PrimitiveDataType;
import com.yahoo.document.datatypes.Array;
import com.yahoo.document.datatypes.BoolFieldValue;
@@ -41,7 +40,6 @@ import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Iterator;
import java.util.Map;
-import java.util.Set;
/**
* @author Steinar Knutsen
@@ -49,7 +47,7 @@ import java.util.Set;
*/
public class JsonSerializationHelper {
- private final static Base64.Encoder base64Encoder = Base64.getEncoder(); // Important: _basic_ format
+ private final static Base64.Encoder base64Encoder = Base64.getEncoder().withoutPadding(); // Important: _basic_ format
static class JsonSerializationException extends RuntimeException {
public JsonSerializationException(Exception base) {
@@ -166,8 +164,7 @@ public class JsonSerializationHelper {
public static void serializeStructField(FieldWriter fieldWriter, JsonGenerator generator, FieldBase field, Struct value) {
DataType dt = value.getDataType();
- if (dt instanceof GeoPosType) {
- var gpt = (GeoPosType)dt;
+ if (dt instanceof GeoPosType gpt) {
if (gpt.renderJsonAsVespa8()) {
serializeGeoPos(generator, field, value, gpt);
return;
diff --git a/document/src/main/java/com/yahoo/document/serialization/DocumentUpdateFlags.java b/document/src/main/java/com/yahoo/document/serialization/DocumentUpdateFlags.java
index e3510676148..11ded80ed2a 100644
--- a/document/src/main/java/com/yahoo/document/serialization/DocumentUpdateFlags.java
+++ b/document/src/main/java/com/yahoo/document/serialization/DocumentUpdateFlags.java
@@ -23,7 +23,7 @@ public class DocumentUpdateFlags {
}
public void setCreateIfNonExistent(boolean value) {
flags &= ~1; // clear flag
- flags |= value ? 1 : 0; // set flag
+ flags |= value ? (byte)1 : (byte)0; // set flag
}
public int injectInto(int value) {
return extractValue(value) | (flags << 28);
diff --git a/document/src/main/java/com/yahoo/document/serialization/XmlSerializationHelper.java b/document/src/main/java/com/yahoo/document/serialization/XmlSerializationHelper.java
index 9c1df0cd6c7..d35693f785f 100644
--- a/document/src/main/java/com/yahoo/document/serialization/XmlSerializationHelper.java
+++ b/document/src/main/java/com/yahoo/document/serialization/XmlSerializationHelper.java
@@ -34,6 +34,8 @@ import java.util.Map;
@SuppressWarnings("removal")
public class XmlSerializationHelper {
+ private final static Base64.Encoder base64Encoder = Base64.getEncoder().withoutPadding();
+
public static void printArrayXml(Array array, XmlStream xml) {
List<FieldValue> lst = array.getValues();
for (FieldValue value : lst) {
@@ -98,7 +100,7 @@ public class XmlSerializationHelper {
public static void printRawXml(Raw r, XmlStream xml) {
xml.addAttribute("binaryencoding", "base64");
- xml.addContent(Base64.getEncoder().encodeToString(r.getByteBuffer().array()));
+ xml.addContent(base64Encoder.encodeToString(r.getByteBuffer().array()));
}
public static void printStringXml(StringFieldValue s, XmlStream xml) {
@@ -106,7 +108,7 @@ public class XmlSerializationHelper {
if (containsNonPrintableCharactersString(content)) {
byte[] bytecontent = Utf8.toBytes(content);
xml.addAttribute("binaryencoding", "base64");
- xml.addContent(Base64.getEncoder().encodeToString(bytecontent));
+ xml.addContent(base64Encoder.encodeToString(bytecontent));
} else {
xml.addContent(content);
}
diff --git a/document/src/test/java/com/yahoo/document/DocumentTestCase.java b/document/src/test/java/com/yahoo/document/DocumentTestCase.java
index 33b77cb1878..4470865b636 100644
--- a/document/src/test/java/com/yahoo/document/DocumentTestCase.java
+++ b/document/src/test/java/com/yahoo/document/DocumentTestCase.java
@@ -52,7 +52,7 @@ public class DocumentTestCase extends DocumentTestCaseBase {
" <mailid>emailfromalicetobob&amp;someone</mailid>\n" +
" <date>-2013512400</date>\n" +
" <attachmentcount>2</attachmentcount>\n" +
- " <rawfield binaryencoding=\"base64\">AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0+P0BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiYw==</rawfield>\n";
+ " <rawfield binaryencoding=\"base64\">AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0+P0BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiYw</rawfield>\n";
private static final String SERTEST_DOC_AS_XML_WEIGHT1 =
" <weightedfield>\n" +
diff --git a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
index 08a5c9a124c..af7469de31b 100644
--- a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
+++ b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
@@ -504,7 +504,7 @@ public class DocumentUpdateJsonSerializerTest {
" 'update': 'DOCUMENT_ID',",
" 'fields': {",
" 'raw_field': {",
- " 'assign': 'RG9uJ3QgYmVsaWV2ZSBoaXMgbGllcw=='",
+ " 'assign': 'RG9uJ3QgYmVsaWV2ZSBoaXMgbGllcw'",
" }",
" }",
"}"
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index 0c130ab9a42..a761a9adfb6 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -668,7 +668,7 @@ public class JsonReaderTestCase {
@Test
public void testRaw() throws IOException {
String base64 = new String(new JsonStringEncoder().quoteAsString(
- Base64.getEncoder().encodeToString(Utf8.toBytes("smoketest"))));
+ Base64.getEncoder().withoutPadding().encodeToString(Utf8.toBytes("smoketest"))));
String s = fieldStringFromBase64RawContent(base64);
assertEquals("smoketest", s);
}
diff --git a/document/src/test/java/com/yahoo/document/json/JsonWriterTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonWriterTestCase.java
index eab33afc3e4..4f15a2fe368 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonWriterTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonWriterTestCase.java
@@ -291,7 +291,7 @@ public class JsonWriterTestCase {
String payload = new String(
new JsonStringEncoder().quoteAsString(
"c3RyaW5nIGxvbmcgZW5vdWdoIHRvIGVtaXQgbW9yZSB0aGFuIDc2IGJhc2U2NCBjaGFyYWN0ZXJzIGFuZC" +
- "B3aGljaCBzaG91bGQgY2VydGFpbmx5IG5vdCBiZSBuZXdsaW5lLWRlbGltaXRlZCE="));
+ "B3aGljaCBzaG91bGQgY2VydGFpbmx5IG5vdCBiZSBuZXdsaW5lLWRlbGltaXRlZCE"));
String docId = "id:unittest:testraw::whee";
diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml
index b58533b32e9..0ae2c68e6a8 100644
--- a/fat-model-dependencies/pom.xml
+++ b/fat-model-dependencies/pom.xml
@@ -96,6 +96,10 @@
<groupId>ai.djl.huggingface</groupId>
<artifactId>tokenizers</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>com.theokanning.openai-gpt3-java</groupId>
+ <artifactId>service</artifactId>
+ </exclusion>
</exclusions>
</dependency>
<dependency>
diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
index d55b58a1728..ea17d9967ae 100644
--- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
+++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
@@ -398,6 +398,12 @@ public class Flags {
"Whether to enable CrowdStrike.", "Takes effect on next host admin tick",
HOSTNAME);
+ public static final UnboundBooleanFlag ALLOW_MORE_THAN_ONE_CONTENT_GROUP_DOWN = defineFeatureFlag(
+ "allow-more-than-one-content-group-down", false, List.of("hmusum"), "2023-04-14", "2023-06-14",
+ "Whether to enable possible configuration of letting more than one content group down",
+ "Takes effect at redeployment",
+ HOSTNAME);
+
/** WARNING: public for testing: All flags should be defined in {@link Flags}. */
public static UnboundBooleanFlag defineFeatureFlag(String flagId, boolean defaultValue, List<String> owners,
String createdAt, String expiresAt, String description,
diff --git a/jdisc_core/src/test/java/com/yahoo/jdisc/core/ExportPackagesIT.java b/jdisc_core/src/test/java/com/yahoo/jdisc/core/ExportPackagesIT.java
index e9aba0893f9..7ec3406be1f 100644
--- a/jdisc_core/src/test/java/com/yahoo/jdisc/core/ExportPackagesIT.java
+++ b/jdisc_core/src/test/java/com/yahoo/jdisc/core/ExportPackagesIT.java
@@ -62,8 +62,8 @@ public class ExportPackagesIT {
String expectedValue = expectedProperties.getProperty(ExportPackages.EXPORT_PACKAGES);
assertNotNull(expectedValue, "Missing exportPackages property in file.");
- Set<String> actualPackages = getPackages(actualValue);
- Set<String> expectedPackages = getPackages(expectedValue);
+ Set<String> actualPackages = removeNewPackageOnJava20(removeJavaVersion(getPackages(actualValue)));
+ Set<String> expectedPackages = removeNewPackageOnJava20(removeJavaVersion(getPackages(expectedValue)));
if (!actualPackages.equals(expectedPackages)) {
StringBuilder message = getDiff(actualPackages, expectedPackages);
message.append("\n\nIf this test fails due to an intentional change in exported packages, run the following command:\n")
@@ -73,6 +73,14 @@ public class ExportPackagesIT {
}
}
+ private static Set<String> removeJavaVersion(Set<String> packages) {
+ return packages.stream().map(p -> p.replaceAll(".JavaSE_\\d+", "")).collect(Collectors.toSet());
+ }
+
+ private static Set<String> removeNewPackageOnJava20(Set<String> packages) {
+ return packages.stream().filter(p -> ! p.contains("java.lang.foreign")).collect(Collectors.toSet());
+ }
+
private static StringBuilder getDiff(Set<String> actual, Set<String> expected) {
StringBuilder sb = new StringBuilder();
Set<String> onlyInActual = onlyInSet1(actual, expected);
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index c27ed9d2c31..c96441f11a7 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -111,6 +111,11 @@
</dependency>
<dependency>
+ <groupId>com.theokanning.openai-gpt3-java</groupId>
+ <artifactId>service</artifactId>
+ </dependency>
+
+ <dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
@@ -146,6 +151,18 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
+ <configuration>
+ <!--
+ openai-gpt3-java depends on a different Jackson version than the one we provide,
+ which leads to warnings, so we must disable error on warnings.
+ -->
+ <compilerArgs>
+ <arg>-Xlint:all</arg>
+ <arg>-Xlint:-rawtypes</arg>
+ <arg>-Xlint:-unchecked</arg>
+ <arg>-Xlint:-serial</arg>
+ </compilerArgs>
+ </configuration>
</plugin>
<plugin>
<groupId>com.github.os72</groupId>
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index b40e2b5be72..bf56d233f89 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -32,10 +32,9 @@ import java.util.Map;
*/
public class BertBaseEmbedder extends AbstractComponent implements Embedder {
- private final static int TOKEN_CLS = 101; // [CLS]
- private final static int TOKEN_SEP = 102; // [SEP]
-
private final int maxTokens;
+ private final int startSequenceToken;
+ private final int endSequenceToken;
private final String inputIdsName;
private final String attentionMaskName;
private final String tokenTypeIdsName;
@@ -48,6 +47,8 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
@Inject
public BertBaseEmbedder(OnnxRuntime onnx, BertBaseEmbedderConfig config) {
maxTokens = config.transformerMaxTokens();
+ startSequenceToken = config.transformerStartSequenceToken();
+ endSequenceToken = config.transformerEndSequenceToken();
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
tokenTypeIdsName = config.transformerTokenTypeIds();
@@ -98,7 +99,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
if (!type.dimensions().get(0).isIndexed()) {
throw new IllegalArgumentException("Error in embedding to type '" + type + "': dimension should be indexed.");
}
- List<Integer> tokens = embedWithSeperatorTokens(text, context, maxTokens);
+ List<Integer> tokens = embedWithSeparatorTokens(text, context, maxTokens);
return embedTokens(tokens, type);
}
@@ -109,6 +110,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
Tensor attentionMask = createAttentionMask(inputSequence);
Tensor tokenTypeIds = createTokenTypeIds(inputSequence);
+
Map<String, Tensor> inputs;
if (!"".equals(tokenTypeIdsName)) {
inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
@@ -138,14 +140,14 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
- private List<Integer> embedWithSeperatorTokens(String text, Context context, int maxLength) {
+ private List<Integer> embedWithSeparatorTokens(String text, Context context, int maxLength) {
List<Integer> tokens = new ArrayList<>();
- tokens.add(TOKEN_CLS);
+ tokens.add(startSequenceToken);
tokens.addAll(embed(text, context));
- tokens.add(TOKEN_SEP);
+ tokens.add(endSequenceToken);
if (tokens.size() > maxLength) {
tokens = tokens.subList(0, maxLength-1);
- tokens.add(TOKEN_SEP);
+ tokens.add(endSequenceToken);
}
return tokens;
}
diff --git a/model-integration/src/main/java/ai/vespa/llm/Completion.java b/model-integration/src/main/java/ai/vespa/llm/Completion.java
new file mode 100644
index 00000000000..5f483a65186
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/Completion.java
@@ -0,0 +1,41 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+import java.util.Objects;
+
+/**
+ * A completion from a language model.
+ *
+ * @author bratseth
+ */
+@Beta
+public record Completion(String text, FinishReason finishReason) {
+
+ public enum FinishReason {
+
+ /** The maximum length of a completion was reached. */
+ length,
+
+ /** The completion is the predicted ending of the prompt. */
+ stop
+
+ }
+
+ public Completion(String text, FinishReason finishReason) {
+ this.text = Objects.requireNonNull(text);
+ this.finishReason = Objects.requireNonNull(finishReason);
+ }
+
+ /** Returns the generated text completion. */
+ public String text() { return text; }
+
+ /** Returns the reason this completion ended. */
+ public FinishReason finishReason() { return finishReason; }
+
+ public static Completion from(String text) {
+ return new Completion(text, FinishReason.stop);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/llm/Generator.java b/model-integration/src/main/java/ai/vespa/llm/Generator.java
index 973b5ac2899..6b60041947b 100644
--- a/model-integration/src/main/java/ai/vespa/llm/Generator.java
+++ b/model-integration/src/main/java/ai/vespa/llm/Generator.java
@@ -13,6 +13,7 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.api.annotations.Beta;
import java.util.ArrayList;
import java.util.List;
@@ -27,6 +28,7 @@ import java.util.Map;
*
* @author lesters
*/
+@Beta
public class Generator extends AbstractComponent {
private final static int TOKEN_EOS = 1; // end of sequence
diff --git a/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java b/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java
index 743bb7c2f27..8b490a733dd 100644
--- a/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java
@@ -1,5 +1,8 @@
package ai.vespa.llm;
+import com.yahoo.api.annotations.Beta;
+
+@Beta
public class GeneratorOptions {
public enum SearchMethod {
diff --git a/model-integration/src/main/java/ai/vespa/llm/LanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/LanguageModel.java
new file mode 100644
index 00000000000..0739162c5ee
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/LanguageModel.java
@@ -0,0 +1,18 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+import java.util.List;
+
+/**
+ * Interface to language models.
+ *
+ * @author bratseth
+ */
+@Beta
+public interface LanguageModel {
+
+ List<Completion> complete(Prompt prompt);
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/llm/Prompt.java b/model-integration/src/main/java/ai/vespa/llm/Prompt.java
new file mode 100644
index 00000000000..77093d5e21b
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/Prompt.java
@@ -0,0 +1,23 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+/**
+ * A prompt that can be given to a large language model to generate a completion.
+ *
+ * @author bratseth
+ */
+@Beta
+public abstract class Prompt {
+
+ public abstract String asString();
+
+ /** Returns a new prompt with the text of the given completion appended. */
+ public Prompt append(Completion completion) {
+ return append(completion.text());
+ }
+
+ public abstract Prompt append(String text);
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/llm/StringPrompt.java b/model-integration/src/main/java/ai/vespa/llm/StringPrompt.java
new file mode 100644
index 00000000000..0af8388dfb1
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/StringPrompt.java
@@ -0,0 +1,43 @@
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+import java.util.Objects;
+
+/**
+ * A prompt which just consists of a string.
+ *
+ * @author bratseth
+ */
+@Beta
+public class StringPrompt extends Prompt {
+
+ private final String string;
+
+ private StringPrompt(String string) {
+ this.string = Objects.requireNonNull(string);
+ }
+
+ @Override
+ public String asString() { return string; }
+
+ @Override
+ public StringPrompt append(String text) {
+ return StringPrompt.from(string + text);
+ }
+
+ @Override
+ public StringPrompt append(Completion completion) {
+ return append(completion.text());
+ }
+
+ @Override
+ public String toString() {
+ return string;
+ }
+
+ public static StringPrompt from(String string) {
+ return new StringPrompt(string);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java b/model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java
new file mode 100644
index 00000000000..3f4475b2482
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java
@@ -0,0 +1,84 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.client;
+
+import ai.vespa.llm.Completion;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.Prompt;
+import com.theokanning.openai.OpenAiHttpException;
+import com.theokanning.openai.completion.CompletionRequest;
+import com.theokanning.openai.service.OpenAiService;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.yolean.Exceptions;
+
+import java.util.List;
+
+/**
+ * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/.
+ *
+ * @author bratseth
+ */
+@Beta
+public class OpenAiClient implements LanguageModel {
+
+ private final OpenAiService openAiService;
+ private final String model;
+ private final boolean echo;
+
+ private OpenAiClient(Builder builder) {
+ openAiService = new OpenAiService(builder.token);
+ this.model = builder.model;
+ this.echo = builder.echo;
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt) {
+ try {
+ CompletionRequest completionRequest = CompletionRequest.builder()
+ .prompt(prompt.asString())
+ .model(model)
+ .echo(echo)
+ .build();
+ return openAiService.createCompletion(completionRequest).getChoices().stream()
+ .map(c -> new Completion(c.getText(), toFinishReason(c.getFinish_reason()))).toList();
+ }
+ catch (OpenAiHttpException e) {
+ throw new RuntimeException(Exceptions.toMessageString(e));
+ }
+ }
+
+ private Completion.FinishReason toFinishReason(String finishReasonString) {
+ return switch(finishReasonString) {
+ case "length" -> Completion.FinishReason.length;
+ case "stop" -> Completion.FinishReason.stop;
+ default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'");
+ };
+ }
+
+ public static class Builder {
+
+ private final String token;
+ private String model = "text-davinci-003";
+ private boolean echo = false;
+
+ public Builder(String token) {
+ this.token = token;
+ }
+
+ /** One of the language models listed at https://platform.openai.com/docs/models */
+ public Builder model(String model) {
+ this.model = model;
+ return this;
+ }
+
+ public Builder echo(boolean echo) {
+ this.echo = echo;
+ return this;
+ }
+
+ public OpenAiClient build() {
+ return new OpenAiClient(this);
+ }
+
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/llm/test/MockLanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/test/MockLanguageModel.java
new file mode 100644
index 00000000000..54b085a451c
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/test/MockLanguageModel.java
@@ -0,0 +1,44 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.test;
+
+import ai.vespa.llm.Completion;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.Prompt;
+import com.yahoo.api.annotations.Beta;
+
+import java.util.List;
+import java.util.function.Function;
+
+/**
+ * @author bratseth
+ */
+@Beta
+public class MockLanguageModel implements LanguageModel {
+
+ private final Function<Prompt, List<Completion>> completer;
+
+ public MockLanguageModel(Builder builder) {
+ completer = builder.completer;
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt) {
+ return completer.apply(prompt);
+ }
+
+ public static class Builder {
+
+ private Function<Prompt, List<Completion>> completer = prompt -> List.of(Completion.from(""));
+
+ public Builder completer(Function<Prompt, List<Completion>> completer) {
+ this.completer = completer;
+ return this;
+ }
+
+ public Builder() {}
+
+ public MockLanguageModel build() { return new MockLanguageModel(this); }
+
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 1ed219a8560..a980ca984ec 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -7,6 +7,9 @@ import ai.onnxruntime.OrtSession;
import java.util.Objects;
+import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.PARALLEL;
+import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
+
/**
* Session options for ONNX Runtime evaluation
*
@@ -24,9 +27,10 @@ public class OnnxEvaluatorOptions {
public OnnxEvaluatorOptions() {
// Defaults:
optimizationLevel = OrtSession.SessionOptions.OptLevel.ALL_OPT;
- executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
- interOpThreads = 1;
- intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4));
+ executionMode = SEQUENTIAL;
+ int quarterVcpu = Math.max(1, (int) Math.ceil(Runtime.getRuntime().availableProcessors() / 4d));
+ interOpThreads = quarterVcpu;
+ intraOpThreads = quarterVcpu;
gpuDeviceNumber = -1;
gpuDeviceRequired = false;
}
@@ -35,7 +39,7 @@ public class OnnxEvaluatorOptions {
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
options.setOptimizationLevel(optimizationLevel);
options.setExecutionMode(executionMode);
- options.setInterOpNumThreads(interOpThreads);
+ options.setInterOpNumThreads(executionMode == PARALLEL ? interOpThreads : 1);
options.setIntraOpNumThreads(intraOpThreads);
if (loadCuda) {
options.addCUDA(gpuDeviceNumber);
@@ -47,7 +51,7 @@ public class OnnxEvaluatorOptions {
if ("parallel".equalsIgnoreCase(mode)) {
executionMode = OrtSession.SessionOptions.ExecutionMode.PARALLEL;
} else if ("sequential".equalsIgnoreCase(mode)) {
- executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
+ executionMode = SEQUENTIAL;
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
index ceb9a27924d..eee60d56c55 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
@@ -96,14 +96,14 @@ public class OrderedTensorType {
* so that they are correctly laid out in memory for Vespa.
* Used when importing tensors.
*/
- public int toDirectIndex(int index) {
+ public long toDirectIndex(int index) {
if (dimensions.size() == 0) {
return 0;
}
if (dimensionMap == null) {
throw new IllegalArgumentException("Dimension map is not available");
}
- int directIndex = 0;
+ long directIndex = 0;
long rest = index;
for (int i = 0; i < dimensions.size(); ++i) {
long address = rest / innerSizesOriginal[i];
diff --git a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
index 14d953eeef9..ef42d81e1fe 100644
--- a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
+++ b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
@@ -17,6 +17,10 @@ transformerInputIds string default=input_ids
transformerAttentionMask string default=attention_mask
transformerTokenTypeIds string default=token_type_ids
+# special token ids
+transformerStartSequenceToken int default=101
+transformerEndSequenceToken int default=102
+
# Output name
transformerOutput string default=output_0
diff --git a/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java b/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java
new file mode 100644
index 00000000000..30b1c8c2fb1
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java
@@ -0,0 +1,37 @@
+package ai.vespa.llm;
+
+import ai.vespa.llm.test.MockLanguageModel;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.function.Function;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * Tests completion with a mock completer.
+ *
+ * @author bratseth
+ */
+public class CompletionTest {
+
+ @Test
+ public void testCompletion() {
+ Function<Prompt, List<Completion>> completer = in ->
+ switch (in.asString()) {
+ case "Complete this: " -> List.of(Completion.from("The completion"));
+ default -> throw new RuntimeException("Cannot complete '" + in + "'");
+ };
+ var llm = new MockLanguageModel.Builder().completer(completer).build();
+
+ String input = "Complete this: ";
+ StringPrompt prompt = StringPrompt.from(input);
+ for (int i = 0; i < 10; i++) {
+ var completion = llm.complete(prompt).get(0);
+ prompt = prompt.append(completion);
+ if (completion.finishReason() == Completion.FinishReason.stop) break;
+ }
+ assertEquals("Complete this: The completion", prompt.asString());
+ }
+
+}
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/reports/DropDocumentsReport.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/reports/DropDocumentsReport.java
new file mode 100644
index 00000000000..0d88f10ebf9
--- /dev/null
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/reports/DropDocumentsReport.java
@@ -0,0 +1,55 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.node.admin.configserver.noderepository.reports;
+
+import com.fasterxml.jackson.annotation.JsonGetter;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+/**
+ * @author freva
+ */
+@JsonIgnoreProperties(ignoreUnknown = true)
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public class DropDocumentsReport extends BaseReport {
+ private static final String REPORT_ID = "dropDocuments";
+ private static final String DROPPED_AT_FIELD = "droppedAt";
+ private static final String READIED_AT_FIELD = "readiedAt";
+ private static final String STARTED_AT_FIELD = "startedAt";
+
+ private final Long droppedAt;
+ private final Long readiedAt;
+ private final Long startedAt;
+
+ public DropDocumentsReport(@JsonProperty(CREATED_FIELD) Long createdMillisOrNull,
+ @JsonProperty(DROPPED_AT_FIELD) Long droppedAtOrNull,
+ @JsonProperty(READIED_AT_FIELD) Long readiedAtOrNull,
+ @JsonProperty(STARTED_AT_FIELD) Long startedAtOrNull) {
+ super(createdMillisOrNull, null);
+ this.droppedAt = droppedAtOrNull;
+ this.readiedAt = readiedAtOrNull;
+ this.startedAt = startedAtOrNull;
+ }
+
+ @JsonGetter(DROPPED_AT_FIELD)
+ public Long droppedAt() { return droppedAt; }
+
+ @JsonGetter(READIED_AT_FIELD)
+ public Long readiedAt() { return readiedAt; }
+
+ @JsonGetter(STARTED_AT_FIELD)
+ public Long startedAt() { return startedAt; }
+
+ public DropDocumentsReport withDroppedAt(long droppedAt) {
+ return new DropDocumentsReport(getCreatedMillisOrNull(), droppedAt, readiedAt, startedAt);
+ }
+
+ public DropDocumentsReport withStartedAt(long startedAt) {
+ return new DropDocumentsReport(getCreatedMillisOrNull(), droppedAt, readiedAt, startedAt);
+ }
+
+ public static String reportId() {
+ return REPORT_ID;
+ }
+
+}
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java
index d22fd667202..3fb9c73367d 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java
@@ -122,7 +122,7 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
Files.createDirectories(privateKeyFile.getParent());
Files.createDirectories(certificateFile.getParent());
Files.createDirectories(identityDocumentFile.getParent());
- registerIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, identityType);
+ registerIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, identityType, athenzIdentity);
return true;
}
@@ -132,11 +132,11 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
var doc = EntityBindingsMapper.readSignedIdentityDocumentFromFile(identityDocumentFile);
if (doc.outdated()) {
context.log(logger, "Identity document is outdated (version=%d)", doc.documentVersion());
- registerIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, identityType);
+ registerIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, identityType, athenzIdentity);
return true;
} else if (isCertificateExpired(expiry, now)) {
context.log(logger, "Certificate has expired (expiry=%s)", expiry.toString());
- registerIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, identityType);
+ registerIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, identityType, athenzIdentity);
return true;
}
@@ -150,7 +150,7 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
return false;
} else {
lastRefreshAttempt.put(context.containerName(), now);
- refreshIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, doc, identityType);
+ refreshIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, doc, identityType, athenzIdentity);
return true;
}
}
@@ -198,12 +198,12 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
now)) > 0;
}
- private void registerIdentity(NodeAgentContext context, ContainerPath privateKeyFile, ContainerPath certificateFile, ContainerPath identityDocumentFile, IdentityType identityType) {
+ private void registerIdentity(NodeAgentContext context, ContainerPath privateKeyFile, ContainerPath certificateFile, ContainerPath identityDocumentFile, IdentityType identityType, AthenzIdentity identity) {
KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA);
SignedIdentityDocument doc = signedIdentityDocument(context, identityType);
CsrGenerator csrGenerator = new CsrGenerator(certificateDnsSuffix, doc.providerService().getFullName());
Pkcs10Csr csr = csrGenerator.generateInstanceCsr(
- context.identity(), doc.providerUniqueId(), doc.ipAddresses(), doc.clusterType(), keyPair);
+ identity, doc.providerUniqueId(), doc.ipAddresses(), doc.clusterType(), keyPair);
// Allow all zts hosts while removing SIS
HostnameVerifier ztsHostNameVerifier = (hostname, sslSession) -> true;
@@ -211,7 +211,7 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
InstanceIdentity instanceIdentity =
ztsClient.registerInstance(
doc.providerService(),
- context.identity(),
+ identity,
EntityBindingsMapper.toAttestationData(doc),
csr);
EntityBindingsMapper.writeSignedIdentityDocumentToFile(identityDocumentFile, doc);
@@ -230,11 +230,11 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
.orElse(ztsEndpoint);
}
private void refreshIdentity(NodeAgentContext context, ContainerPath privateKeyFile, ContainerPath certificateFile,
- ContainerPath identityDocumentFile, SignedIdentityDocument doc, IdentityType identityType) {
+ ContainerPath identityDocumentFile, SignedIdentityDocument doc, IdentityType identityType, AthenzIdentity identity) {
KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA);
CsrGenerator csrGenerator = new CsrGenerator(certificateDnsSuffix, doc.providerService().getFullName());
Pkcs10Csr csr = csrGenerator.generateInstanceCsr(
- context.identity(), doc.providerUniqueId(), doc.ipAddresses(), doc.clusterType(), keyPair);
+ identity, doc.providerUniqueId(), doc.ipAddresses(), doc.clusterType(), keyPair);
SSLContext containerIdentitySslContext = new SslContextBuilder().withKeyStore(privateKeyFile, certificateFile)
.withTrustStore(ztsTrustStorePath)
@@ -247,7 +247,7 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
InstanceIdentity instanceIdentity =
ztsClient.refreshInstance(
doc.providerService(),
- context.identity(),
+ identity,
doc.providerUniqueId().asDottedString(),
csr);
writePrivateKeyAndCertificate(privateKeyFile, keyPair.getPrivate(), certificateFile, instanceIdentity.certificate());
@@ -255,7 +255,7 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
} catch (ZtsClientException e) {
if (e.getErrorCode() == 403 && e.getDescription().startsWith("Certificate revoked")) {
context.log(logger, Level.SEVERE, "Certificate cannot be refreshed as it is revoked by ZTS - re-registering the instance now", e);
- registerIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, identityType);
+ registerIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, identityType, identity);
} else {
throw e;
}
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java
index 20359410321..f2f690106fa 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java
@@ -16,6 +16,7 @@ import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeMembers
import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeRepository;
import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeSpec;
import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeState;
+import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.reports.DropDocumentsReport;
import com.yahoo.vespa.hosted.node.admin.configserver.orchestrator.Orchestrator;
import com.yahoo.vespa.hosted.node.admin.configserver.orchestrator.OrchestratorException;
import com.yahoo.vespa.hosted.node.admin.container.Container;
@@ -29,6 +30,7 @@ import com.yahoo.vespa.hosted.node.admin.maintenance.acl.AclMaintainer;
import com.yahoo.vespa.hosted.node.admin.maintenance.identity.CredentialsMaintainer;
import com.yahoo.vespa.hosted.node.admin.maintenance.servicedump.VespaServiceDumper;
import com.yahoo.vespa.hosted.node.admin.nodeadmin.ConvergenceException;
+import com.yahoo.vespa.hosted.node.admin.task.util.file.FileFinder;
import java.time.Clock;
import java.time.Duration;
@@ -228,6 +230,12 @@ public class NodeAgentImpl implements NodeAgent {
changed = true;
}
+ Optional<DropDocumentsReport> report = context.node().reports().getReport(DropDocumentsReport.reportId(), DropDocumentsReport.class);
+ if (report.isPresent() && report.get().startedAt() == null && report.get().readiedAt() != null) {
+ newNodeAttributes.withReport(DropDocumentsReport.reportId(), report.get().withStartedAt(clock.millis()).toJsonNode());
+ changed = true;
+ }
+
if (changed) {
context.log(logger, "Publishing new set of attributes to node repo: %s -> %s",
currentNodeAttributes, newNodeAttributes);
@@ -433,6 +441,21 @@ public class NodeAgentImpl implements NodeAgent {
.orElse(false);
}
+ private void dropDocsIfNeeded(NodeAgentContext context, Optional<Container> container) {
+ Optional<DropDocumentsReport> report = context.node().reports()
+ .getReport(DropDocumentsReport.reportId(), DropDocumentsReport.class);
+ if (report.isEmpty() || report.get().readiedAt() != null) return;
+
+ if (report.get().droppedAt() == null) {
+ container.ifPresent(c -> removeContainer(context, c, List.of("Dropping documents"), true));
+ FileFinder.from(context.paths().underVespaHome("var/db/vespa/search")).deleteRecursively(context);
+ nodeRepository.updateNodeAttributes(context.node().hostname(),
+ new NodeAttributes().withReport(DropDocumentsReport.reportId(), report.get().withDroppedAt(clock.millis()).toJsonNode()));
+ }
+
+ throw ConvergenceException.ofTransient("Documents already dropped, waiting for signal to start the container");
+ }
+
public void converge(NodeAgentContext context) {
try {
doConverge(context);
@@ -494,6 +517,7 @@ public class NodeAgentImpl implements NodeAgent {
context.log(logger, "Waiting for image to download " + context.node().wantedDockerImage().get().asString());
return;
}
+ dropDocsIfNeeded(context, container);
container = removeContainerIfNeededUpdateContainerState(context, container);
credentialsMaintainers.forEach(maintainer -> maintainer.converge(context));
if (container.isEmpty()) {
diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImplTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImplTest.java
index b8b72308bdd..2db5314dbf2 100644
--- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImplTest.java
+++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImplTest.java
@@ -14,6 +14,7 @@ import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeReposit
import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeSpec;
import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeState;
import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.OrchestratorStatus;
+import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.reports.DropDocumentsReport;
import com.yahoo.vespa.hosted.node.admin.configserver.orchestrator.Orchestrator;
import com.yahoo.vespa.hosted.node.admin.configserver.orchestrator.OrchestratorException;
import com.yahoo.vespa.hosted.node.admin.container.Container;
@@ -27,6 +28,7 @@ import com.yahoo.vespa.hosted.node.admin.maintenance.acl.AclMaintainer;
import com.yahoo.vespa.hosted.node.admin.maintenance.identity.CredentialsMaintainer;
import com.yahoo.vespa.hosted.node.admin.maintenance.servicedump.VespaServiceDumper;
import com.yahoo.vespa.hosted.node.admin.nodeadmin.ConvergenceException;
+import com.yahoo.vespa.hosted.node.admin.task.util.file.UnixPath;
import com.yahoo.vespa.test.file.TestFileSystem;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -38,8 +40,11 @@ import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.function.BiFunction;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
@@ -739,6 +744,56 @@ public class NodeAgentImplTest {
inOrder.verify(orchestrator, times(1)).resume(eq(hostName));
}
+ @Test
+ void drop_all_documents() {
+ InOrder inOrder = inOrder(orchestrator, nodeRepository);
+ BiFunction<NodeState, DropDocumentsReport, NodeSpec> specBuilder = (state, report) -> (report == null ?
+ nodeBuilder(state) : nodeBuilder(state).report(DropDocumentsReport.reportId(), report.toJsonNode()))
+ .wantedDockerImage(dockerImage).currentDockerImage(dockerImage)
+ .build();
+ NodeAgentImpl nodeAgent = makeNodeAgent(dockerImage, true, Duration.ofSeconds(30));
+
+ NodeAgentContext context = createContext(specBuilder.apply(NodeState.active, null));
+ UnixPath indexPath = new UnixPath(context.paths().underVespaHome("var/db/vespa/search/cluster.foo/0/doc")).createParents().createNewFile();
+ mockGetContainer(dockerImage, ContainerResources.from(2, 2, 16), true);
+ assertTrue(indexPath.exists());
+
+ // Initially no changes, index is not dropped
+ nodeAgent.converge(context);
+ assertTrue(indexPath.exists());
+ inOrder.verifyNoMoreInteractions();
+
+ context = createContext(specBuilder.apply(NodeState.active, new DropDocumentsReport(1L, null, null, null)));
+ nodeAgent.converge(context);
+ verify(containerOperations).removeContainer(eq(context), any());
+ assertFalse(indexPath.exists());
+ inOrder.verify(nodeRepository).updateNodeAttributes(eq(hostName), eq(new NodeAttributes().withReport(DropDocumentsReport.reportId(), new DropDocumentsReport(1L, clock.millis(), null, null).toJsonNode())));
+ inOrder.verifyNoMoreInteractions();
+
+ // After droppedAt and before readiedAt are set, we cannot proceed
+ mockGetContainer(null, false);
+ context = createContext(specBuilder.apply(NodeState.active, new DropDocumentsReport(1L, 2L, null, null)));
+ nodeAgent.converge(context);
+ verify(containerOperations, never()).removeContainer(eq(context), any());
+ verify(containerOperations, never()).startContainer(eq(context));
+ inOrder.verifyNoMoreInteractions();
+
+ context = createContext(specBuilder.apply(NodeState.active, new DropDocumentsReport(1L, 2L, 3L, null)));
+ nodeAgent.converge(context);
+ verify(containerOperations).startContainer(eq(context));
+ inOrder.verifyNoMoreInteractions();
+
+ mockGetContainer(dockerImage, ContainerResources.from(0, 2, 16), true);
+ clock.advance(Duration.ofSeconds(31));
+ nodeAgent.converge(context);
+ verify(containerOperations, times(1)).startContainer(eq(context));
+ verify(containerOperations, never()).removeContainer(eq(context), any());
+ inOrder.verify(nodeRepository).updateNodeAttributes(eq(hostName), eq(new NodeAttributes()
+ .withRebootGeneration(0)
+ .withReport(DropDocumentsReport.reportId(), new DropDocumentsReport(1L, 2L, 3L, clock.millis()).toJsonNode())));
+ inOrder.verifyNoMoreInteractions();
+ }
+
private void verifyThatContainerIsStopped(NodeState nodeState, Optional<ApplicationId> owner) {
NodeSpec.Builder nodeBuilder = nodeBuilder(nodeState)
.type(NodeType.tenant)
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityChecker.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityChecker.java
index c8b736cb25b..f3ea326a3c0 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityChecker.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityChecker.java
@@ -130,7 +130,7 @@ public class CapacityChecker {
Set<String> ipPool = host.ipConfig().pool().asSet();
for (var child : nodeChildren.get(host)) {
hostResources = hostResources.subtract(child.resources().justNumbers());
- occupiedIps += child.ipConfig().primary().stream().filter(ipPool::contains).count();
+ occupiedIps += (int)child.ipConfig().primary().stream().filter(ipPool::contains).count();
}
availableResources.put(host, new AllocationResources(hostResources, host.ipConfig().pool().asSet().size() - occupiedIps));
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java
index 890d190c24e..b3198a72d1b 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java
@@ -193,8 +193,8 @@ public class LoadBalancerProvisioner {
Optional<LoadBalancer> loadBalancer = db.readLoadBalancer(id);
LoadBalancer newLoadBalancer;
LoadBalancer.State fromState = loadBalancer.map(LoadBalancer::state).orElse(null);
- boolean recreateLoadBalancer = loadBalancer.isPresent() && (!inAccount(cloudAccount, loadBalancer.get())
- || !hasCorrectVisibility(loadBalancer.get(), zoneEndpoint));
+ boolean recreateLoadBalancer = loadBalancer.isPresent() && ( ! inAccount(cloudAccount, loadBalancer.get())
+ || ! hasCorrectVisibility(loadBalancer.get(), zoneEndpoint));
if (recreateLoadBalancer) {
// We have a load balancer, but with the wrong account or visibility.
// Load balancer must be removed before we can provision a new one with the wanted visibility
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodePatcher.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodePatcher.java
index dfe01f5f1c3..bbe287fc034 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodePatcher.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodePatcher.java
@@ -11,8 +11,10 @@ import com.yahoo.config.provision.NodeFlavors;
import com.yahoo.config.provision.NodeResources;
import com.yahoo.config.provision.TenantName;
import com.yahoo.config.provision.WireguardKey;
+import com.yahoo.slime.Cursor;
import com.yahoo.slime.Inspector;
import com.yahoo.slime.ObjectTraverser;
+import com.yahoo.slime.Slime;
import com.yahoo.slime.SlimeUtils;
import com.yahoo.slime.Type;
import com.yahoo.vespa.hosted.provision.LockedNodeList;
@@ -40,6 +42,7 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
+import java.util.stream.Stream;
import static com.yahoo.config.provision.NodeResources.DiskSpeed.fast;
import static com.yahoo.config.provision.NodeResources.DiskSpeed.slow;
@@ -54,9 +57,13 @@ import static com.yahoo.config.provision.NodeResources.StorageType.remote;
*/
public class NodePatcher {
+ // Same as in DropDocumentsReport.java
+ private static final String DROP_DOCUMENTS_REPORT = "dropDocuments";
+
private static final String WANT_TO_RETIRE = "wantToRetire";
private static final String WANT_TO_DEPROVISION = "wantToDeprovision";
private static final String WANT_TO_REBUILD = "wantToRebuild";
+ private static final String REPORTS = "reports";
private static final Set<String> RECURSIVE_FIELDS = Set.of(WANT_TO_RETIRE, WANT_TO_DEPROVISION);
private static final Set<String> IP_CONFIG_FIELDS = Set.of("ipAddresses",
"additionalIpAddresses",
@@ -133,7 +140,29 @@ public class NodePatcher {
throw new IllegalArgumentException("Could not set field '" + name + "'", e);
}
}
- nodeRepository.nodes().write(node, lock);
+ List<Node> nodes = List.of(node);
+ if (node.state() == Node.State.active && isInDocumentsDroppedState(root.field(REPORTS).field(DROP_DOCUMENTS_REPORT))) {
+ NodeList clusterNodes = nodeRepository.nodes()
+ .list(Node.State.active)
+ .except(node)
+ .owner(node.allocation().get().owner())
+ .cluster(node.allocation().get().membership().cluster().id());
+ boolean allNodesDroppedDocuments = clusterNodes.stream().allMatch(cNode ->
+ cNode.reports().getReport(DROP_DOCUMENTS_REPORT).map(report -> isInDocumentsDroppedState(report.getInspector())).orElse(false));
+ if (allNodesDroppedDocuments) {
+ nodes = Stream.concat(nodes.stream(), clusterNodes.stream())
+ .map(cNode -> {
+ Cursor reportRoot = new Slime().setObject();
+ Report report = cNode.reports().getReport(DROP_DOCUMENTS_REPORT).get();
+ report.toSlime(reportRoot);
+ reportRoot.setLong("readiedAt", clock.millis());
+
+ return cNode.with(cNode.reports().withReport(Report.fromSlime(DROP_DOCUMENTS_REPORT, reportRoot)));
+ })
+ .toList();
+ }
+ }
+ nodeRepository.nodes().write(nodes, lock);
}
}
@@ -202,18 +231,15 @@ public class NodePatcher {
.orElseGet(node.status()::wantToRebuild),
Agent.operator,
clock.instant());
- case "reports" :
+ case REPORTS:
return nodeWithPatchedReports(node, value);
- case "id" :
+ case "id":
return node.withId(asString(value));
case "diskGb":
- case "minDiskAvailableGb":
return node.with(node.flavor().with(node.flavor().resources().withDiskGb(value.asDouble())), Agent.operator, clock.instant());
case "memoryGb":
- case "minMainMemoryAvailableGb":
return node.with(node.flavor().with(node.flavor().resources().withMemoryGb(value.asDouble())), Agent.operator, clock.instant());
case "vcpu":
- case "minCpuCores":
return node.with(node.flavor().with(node.flavor().resources().withVcpu(value.asDouble())), Agent.operator, clock.instant());
case "fastDisk":
return node.with(node.flavor().with(node.flavor().resources().with(value.asBool() ? fast : slow)), Agent.operator, clock.instant());
@@ -244,18 +270,12 @@ public class NodePatcher {
}
private Node applyIpconfigField(Node node, String name, Inspector value, LockedNodeList nodes) {
- switch (name) {
- case "ipAddresses" -> {
- return IP.Config.verify(node.with(node.ipConfig().withPrimary(asStringSet(value))), nodes);
- }
- case "additionalIpAddresses" -> {
- return IP.Config.verify(node.with(node.ipConfig().withPool(node.ipConfig().pool().withIpAddresses(asStringSet(value)))), nodes);
- }
- case "additionalHostnames" -> {
- return IP.Config.verify(node.with(node.ipConfig().withPool(node.ipConfig().pool().withHostnames(asHostnames(value)))), nodes);
- }
- }
- throw new IllegalArgumentException("Could not apply field '" + name + "' on a node: No such modifiable field");
+ return switch (name) {
+ case "ipAddresses" -> IP.Config.verify(node.with(node.ipConfig().withPrimary(asStringSet(value))), nodes);
+ case "additionalIpAddresses" -> IP.Config.verify(node.with(node.ipConfig().withPool(node.ipConfig().pool().withIpAddresses(asStringSet(value)))), nodes);
+ case "additionalHostnames" -> IP.Config.verify(node.with(node.ipConfig().withPool(node.ipConfig().pool().withHostnames(asHostnames(value)))), nodes);
+ default -> throw new IllegalArgumentException("Could not apply field '" + name + "' on a node: No such modifiable field");
+ };
}
private Node nodeWithPatchedReports(Node node, Inspector reportsInspector) {
@@ -374,4 +394,9 @@ public class NodePatcher {
return Optional.of(field).filter(Inspector::valid).map(this::asBoolean);
}
+ private static boolean isInDocumentsDroppedState(Inspector report) {
+ if (!report.valid()) return false;
+ return report.field("droppedAt").valid() && !report.field("readiedAt").valid();
+ }
+
}
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/NodesV2ApiTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/NodesV2ApiTest.java
index c9e57c22d11..7affcfebdb3 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/NodesV2ApiTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/NodesV2ApiTest.java
@@ -647,6 +647,22 @@ public class NodesV2ApiTest {
Request.Method.PATCH),
"{\"message\":\"Updated dockerhost1.yahoo.com\"}");
assertFile(new Request("http://localhost:8080/nodes/v2/node/dockerhost1.yahoo.com"), "docker-node1-reports-4.json");
+
+ assertResponse(new Request("http://localhost:8080/nodes/v2/node/host1.yahoo.com",
+ Utf8.toBytes("{\"reports\": {\"dropDocuments\":{\"createdMillis\":25,\"droppedAt\":36}}}"),
+ Request.Method.PATCH),
+ "{\"message\":\"Updated host1.yahoo.com\"}");
+ tester.assertResponseContains(new Request("http://localhost:8080/nodes/v2/node/host1.yahoo.com"),
+ "{\"dropDocuments\":{\"createdMillis\":25,\"droppedAt\":36}}");
+
+ assertResponse(new Request("http://localhost:8080/nodes/v2/node/host10.yahoo.com",
+ Utf8.toBytes("{\"reports\": {\"dropDocuments\":{\"createdMillis\":49,\"droppedAt\":456}}}"),
+ Request.Method.PATCH),
+ "{\"message\":\"Updated host10.yahoo.com\"}");
+ tester.assertResponseContains(new Request("http://localhost:8080/nodes/v2/node/host10.yahoo.com"),
+ "{\"dropDocuments\":{\"createdMillis\":49,\"droppedAt\":456,\"readiedAt\":123}}");
+ tester.assertResponseContains(new Request("http://localhost:8080/nodes/v2/node/host1.yahoo.com"),
+ "{\"dropDocuments\":{\"createdMillis\":25,\"droppedAt\":36,\"readiedAt\":123}}");
}
@Test
@@ -906,13 +922,13 @@ public class NodesV2ApiTest {
// Test patching with overrides
tester.assertResponse(new Request("http://localhost:8080/nodes/v2/node/" + host,
- "{\"minDiskAvailableGb\":5432,\"minMainMemoryAvailableGb\":2345}".getBytes(StandardCharsets.UTF_8),
+ "{\"diskGb\":5432,\"memoryGb\":2345}".getBytes(StandardCharsets.UTF_8),
Request.Method.PATCH),
400,
- "{\"error-code\":\"BAD_REQUEST\",\"message\":\"Could not set field 'minMainMemoryAvailableGb': Can only override disk GB for configured flavor\"}");
+ "{\"error-code\":\"BAD_REQUEST\",\"message\":\"Could not set field 'memoryGb': Can only override disk GB for configured flavor\"}");
assertResponse(new Request("http://localhost:8080/nodes/v2/node/" + host,
- "{\"minDiskAvailableGb\":5432}".getBytes(StandardCharsets.UTF_8),
+ "{\"diskGb\":5432}".getBytes(StandardCharsets.UTF_8),
Request.Method.PATCH),
"{\"message\":\"Updated " + host + "\"}");
tester.assertResponseContains(new Request("http://localhost:8080/nodes/v2/node/" + host),
diff --git a/parent/pom.xml b/parent/pom.xml
index 8d2f802e34b..76f4ef30dda 100644
--- a/parent/pom.xml
+++ b/parent/pom.xml
@@ -565,6 +565,17 @@
<version>${onnxruntime.version}</version>
</dependency>
<dependency>
+ <groupId>com.theokanning.openai-gpt3-java</groupId>
+ <artifactId>service</artifactId>
+ <version>${openai-gpt3.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
<groupId>com.yahoo.athenz</groupId>
<artifactId>athenz-cert-refresher</artifactId>
<version>${athenz.version}</version>
@@ -1009,6 +1020,11 @@
<version>${junit.version}</version>
</dependency>
<dependency>
+ <groupId>org.junit.jupiter</groupId>
+ <artifactId>junit-jupiter-params</artifactId>
+ <version>${junit.version}</version>
+ </dependency>
+ <dependency>
<groupId>org.junit.vintage</groupId>
<artifactId>junit-vintage-engine</artifactId>
<version>${junit.version}</version>
@@ -1166,6 +1182,7 @@
<netty.version>4.1.86.Final</netty.version>
<netty-tcnative.version>2.0.54.Final</netty-tcnative.version>
<onnxruntime.version>1.13.1</onnxruntime.version> <!-- WARNING: sync cloud-tenant-base-dependencies-enforcer/pom.xml -->
+ <openai-gpt3.version>0.12.0</openai-gpt3.version>
<org.json.version>20230227</org.json.version>
<org.lz4.version>1.8.0</org.lz4.version>
<prometheus.client.version>0.6.0</prometheus.client.version>
diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/benchmarks/ResultMetrics.java b/predicate-search/src/main/java/com/yahoo/search/predicate/benchmarks/ResultMetrics.java
index 11103a2a66a..ef65d9e2efa 100644
--- a/predicate-search/src/main/java/com/yahoo/search/predicate/benchmarks/ResultMetrics.java
+++ b/predicate-search/src/main/java/com/yahoo/search/predicate/benchmarks/ResultMetrics.java
@@ -62,8 +62,8 @@ public class ResultMetrics {
}
private double percentile(double percentile) {
- int targetCount = (int) Math.round(totalQueries * percentile);
- int currentCount = 0;
+ long targetCount = Math.round(totalQueries * percentile);
+ long currentCount = 0;
int index = 0;
while (currentCount < targetCount && index < SLOTS) {
currentCount += latencyHistogram[index];
diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java
index e6db1dec7c3..eb8b0b9927b 100644
--- a/predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java
+++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java
@@ -75,7 +75,7 @@ public class CachedPostingListCounter {
private void countUsingBitVector(byte[] nPostingListsForDocument, int postingListBitmap) {
for (int docId = 0; docId < nDocuments; docId++) {
- nPostingListsForDocument[docId] += Integer.bitCount(bitVector[docId] & postingListBitmap);
+ nPostingListsForDocument[docId] += (byte)Integer.bitCount(bitVector[docId] & postingListBitmap);
}
}
@@ -88,8 +88,7 @@ public class CachedPostingListCounter {
}
public CachedPostingListCounter rebuildCache() {
- MinMaxPriorityQueue<Entry> mostExpensive = MinMaxPriorityQueue
- .maximumSize(32).expectedSize(32).create();
+ MinMaxPriorityQueue<Entry> mostExpensive = MinMaxPriorityQueue.maximumSize(32).expectedSize(32).create();
synchronized (this) {
for (ObjectLongPair<int[]> p : frequency.keyValuesView()) {
mostExpensive.add(new Entry(p.getOne(), p.getTwo()));
diff --git a/searchcore/src/tests/proton/attribute/attribute_initializer/attribute_initializer_test.cpp b/searchcore/src/tests/proton/attribute/attribute_initializer/attribute_initializer_test.cpp
index 4af23a1d7fb..d2798c16065 100644
--- a/searchcore/src/tests/proton/attribute/attribute_initializer/attribute_initializer_test.cpp
+++ b/searchcore/src/tests/proton/attribute/attribute_initializer/attribute_initializer_test.cpp
@@ -277,11 +277,7 @@ TEST("require that reserved document is reinitialized during load")
auto read_view = mvav->make_read_view(IMultiValueAttribute::WeightedSetTag<const char*>(), stash);
ASSERT_TRUE(read_view != nullptr);
auto reserved_values = read_view->get_values(0u);
- EXPECT_EQUAL(1u, reserved_values.size());
- if (reserved_values.size() >= 1) {
- EXPECT_EQUAL(1, reserved_values[0].weight());
- EXPECT_EQUAL(vespalib::string(""), vespalib::string(reserved_values[0].value()));
- }
+ EXPECT_EQUAL(0u, reserved_values.size());
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/expression/Int16ResultNode.java b/searchlib/src/main/java/com/yahoo/searchlib/expression/Int16ResultNode.java
index ae7d0a67b2f..b0f98685578 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/expression/Int16ResultNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/expression/Int16ResultNode.java
@@ -80,7 +80,7 @@ public class Int16ResultNode extends NumericResultNode {
@Override
public void add(ResultNode rhs) {
- value += rhs.getInteger();
+ value += (short)rhs.getInteger();
}
@Override
@@ -90,7 +90,7 @@ public class Int16ResultNode extends NumericResultNode {
@Override
public void multiply(ResultNode rhs) {
- value *= rhs.getInteger();
+ value *= (short)rhs.getInteger();
}
@Override
@@ -101,7 +101,7 @@ public class Int16ResultNode extends NumericResultNode {
@Override
public void modulo(ResultNode rhs) {
- value %= rhs.getInteger();
+ value %= (short)rhs.getInteger();
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/expression/Int32ResultNode.java b/searchlib/src/main/java/com/yahoo/searchlib/expression/Int32ResultNode.java
index da31cbc236a..711b8f1bd3f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/expression/Int32ResultNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/expression/Int32ResultNode.java
@@ -80,7 +80,7 @@ public class Int32ResultNode extends NumericResultNode {
@Override
public void add(ResultNode rhs) {
- value += rhs.getInteger();
+ value += (int)rhs.getInteger();
}
@Override
@@ -90,7 +90,7 @@ public class Int32ResultNode extends NumericResultNode {
@Override
public void multiply(ResultNode rhs) {
- value *= rhs.getInteger();
+ value *= (int)rhs.getInteger();
}
@Override
@@ -101,7 +101,7 @@ public class Int32ResultNode extends NumericResultNode {
@Override
public void modulo(ResultNode rhs) {
- value %= rhs.getInteger();
+ value %= (int)rhs.getInteger();
}
@Override
@@ -122,7 +122,7 @@ public class Int32ResultNode extends NumericResultNode {
@Override
public Object getNumber() {
- return Integer.valueOf(value);
+ return value;
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/expression/Int8ResultNode.java b/searchlib/src/main/java/com/yahoo/searchlib/expression/Int8ResultNode.java
index ae53cf45a6f..d6706ce1dfe 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/expression/Int8ResultNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/expression/Int8ResultNode.java
@@ -78,7 +78,7 @@ public class Int8ResultNode extends NumericResultNode {
@Override
public void add(ResultNode rhs) {
- value += rhs.getInteger();
+ value += (byte)rhs.getInteger();
}
@Override
@@ -88,7 +88,7 @@ public class Int8ResultNode extends NumericResultNode {
@Override
public void multiply(ResultNode rhs) {
- value *= rhs.getInteger();
+ value *= (byte)rhs.getInteger();
}
@Override
@@ -99,7 +99,7 @@ public class Int8ResultNode extends NumericResultNode {
@Override
public void modulo(ResultNode rhs) {
- value %= rhs.getInteger();
+ value %= (byte)rhs.getInteger();
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/expression/RawResultNode.java b/searchlib/src/main/java/com/yahoo/searchlib/expression/RawResultNode.java
index 5a0e056f254..d1dc46fc4d0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/expression/RawResultNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/expression/RawResultNode.java
@@ -18,8 +18,8 @@ public class RawResultNode extends SingleResultNode {
// The global class identifier shared with C++.
public static final int classId = registerClass(0x4000 + 54, RawResultNode.class);
- private static RawResultNode negativeInfinity = new RawResultNode();
- private static PositiveInfinityResultNode positiveInfinity = new PositiveInfinityResultNode();
+ private static final RawResultNode negativeInfinity = new RawResultNode();
+ private static final PositiveInfinityResultNode positiveInfinity = new PositiveInfinityResultNode();
// The raw value of this node.
private RawData value = null;
@@ -147,7 +147,7 @@ public class RawResultNode extends SingleResultNode {
@Override
public Object getValue() {
- return getString();
+ return value;
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/ranking/features/fieldmatch/FieldMatchMetrics.java b/searchlib/src/main/java/com/yahoo/searchlib/ranking/features/fieldmatch/FieldMatchMetrics.java
index 2b5efdb1ffe..5b6a53a7019 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/ranking/features/fieldmatch/FieldMatchMetrics.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/ranking/features/fieldmatch/FieldMatchMetrics.java
@@ -16,7 +16,7 @@ import static java.lang.Math.*;
public final class FieldMatchMetrics implements Cloneable {
/** The calculator creating this - given on initialization */
- private FieldMatchMetricsComputer source;
+ private final FieldMatchMetricsComputer source;
/** The trace accumulated during execution - empty if no tracing */
private final Trace trace = new Trace();
@@ -75,7 +75,7 @@ public final class FieldMatchMetrics implements Cloneable {
currentSequence=0;
segmentStarts.clear();
- queryLength=source.getQuery().getTerms().length;
+ queryLength = source.getQuery().getTerms().length;
}
/** Are these metrics representing a complete match */
@@ -93,7 +93,7 @@ public final class FieldMatchMetrics implements Cloneable {
*/
public float get(String name) {
try {
- Method getter=getClass().getMethod("get" + name.substring(0,1).toUpperCase() + name.substring(1));
+ Method getter = getClass().getMethod("get" + name.substring(0, 1).toUpperCase() + name.substring(1));
return ((Number)getter.invoke(this)).floatValue();
}
catch (NoSuchMethodException e) {
@@ -140,7 +140,7 @@ public final class FieldMatchMetrics implements Cloneable {
* segment or out of order
*/
public float getAbsoluteProximity() {
- if (pairs <1) return 0.1f;
+ if (pairs < 1) return 0.1f;
return proximity/pairs;
}
@@ -151,7 +151,7 @@ public final class FieldMatchMetrics implements Cloneable {
* following each other in sequence, and close to 0 if they are far from each other or out of order
*/
public float getUnweightedProximity() {
- if (pairs <1) return 1f;
+ if (pairs < 1) return 1f;
return unweightedProximity/pairs;
}
@@ -271,33 +271,33 @@ public final class FieldMatchMetrics implements Cloneable {
* <code>queryCompleteness * ( 1 - fieldCompletenessImportance) + fieldCompletenessImportance * fieldCompleteness</code>
*/
public float getCompleteness() {
- float fieldCompletenessImportance=source.getParameters().getFieldCompletenessImportance();
+ float fieldCompletenessImportance = source.getParameters().getFieldCompletenessImportance();
return getQueryCompleteness() * ( 1 - fieldCompletenessImportance) + fieldCompletenessImportance*getFieldCompleteness();
}
/** Returns how well the order of the terms agreed in segments: <code>1-outOfOrder/pairs</code> */
public float getOrderness() {
- if (pairs ==0) return 1f;
+ if (pairs == 0) return 1f;
return 1-(float)outOfOrder/pairs;
}
/** Returns the degree to which different terms are related (occurring in the same segment): <code>1-segments/(matches-1)</code> */
public float getRelatedness() {
- if (matches==0) return 0;
- if (matches==1) return 1;
- return 1-(float)(segments-1)/(matches-1);
+ if (matches == 0) return 0;
+ if (matches == 1) return 1;
+ return 1 - (float)(segments - 1) / (matches - 1);
}
/** Returns <code>longestSequence/matches</code> */
public float getLongestSequenceRatio() {
- if (matches==0) return 0;
- return (float)longestSequence/matches;
+ if (matches == 0) return 0;
+ return (float)longestSequence / matches;
}
/** Returns the closeness of the segments in the field: <code>1-segmentDistance/fieldLength</code> */
public float getSegmentProximity() {
- if (matches==0) return 0;
- return 1-segmentDistance/source.getField().terms().size();
+ if (matches == 0) return 0;
+ return 1 - segmentDistance / source.getField().terms().size();
}
/**
@@ -306,14 +306,14 @@ public final class FieldMatchMetrics implements Cloneable {
* This is absoluteProximity/average connectedness.
*/
public float getProximity() {
- float totalConnectedness=0;
- for (int i=1; i<queryLength; i++) {
- totalConnectedness+=Math.max(0.1,source.getQuery().getTerms()[i].getConnectedness());
+ float totalConnectedness = 0;
+ for (int i = 1; i < queryLength; i++) {
+ totalConnectedness += (float)Math.max(0.1, source.getQuery().getTerms()[i].getConnectedness());
}
- float averageConnectedness=0.1f;
- if (queryLength>1)
- averageConnectedness=totalConnectedness/(queryLength-1);
- return getAbsoluteProximity()/averageConnectedness;
+ float averageConnectedness = 0.1f;
+ if (queryLength > 1)
+ averageConnectedness = totalConnectedness / (queryLength - 1);
+ return getAbsoluteProximity() / averageConnectedness;
}
/**
@@ -378,7 +378,7 @@ public final class FieldMatchMetrics implements Cloneable {
* not only when the metrics are complete, because this metric is used to choose segments during calculation.</p>
*/
float getSegmentationScore() {
- if (segments==0) return 0;
+ if (segments == 0) return 0;
return getAbsoluteProximity() * getExactness() / (segments * segments);
}
@@ -389,7 +389,7 @@ public final class FieldMatchMetrics implements Cloneable {
/** Called once for every match */
void onMatch(int i, int j) {
- if (matches>=source.getField().terms().size()) return;
+ if (matches >= source.getField().terms().size()) return;
matches++;
weight += (float)source.getQuery().getTerms()[i].getWeight() / source.getQuery().getTotalTermWeight();
significance += source.getQuery().getTerms()[i].getSignificance() / source.getQuery().getTotalSignificance();
@@ -418,42 +418,42 @@ public final class FieldMatchMetrics implements Cloneable {
}
/** Called once when this value is calculated, before onComplete */
- void setOccurrence(float occurrence) { this.occurrence=occurrence; }
+ void setOccurrence(float occurrence) { this.occurrence = occurrence; }
/** Called once when this value is calculated, before onComplete */
- void setWeightedOccurrence(float weightedOccurrence) { this.weightedOccurrence=weightedOccurrence; }
+ void setWeightedOccurrence(float weightedOccurrence) { this.weightedOccurrence = weightedOccurrence; }
/** Called once when this value is calculated, before onComplete */
- void setAbsoluteOccurrence(float absoluteOccurrence) { this.absoluteOccurrence=absoluteOccurrence; }
+ void setAbsoluteOccurrence(float absoluteOccurrence) { this.absoluteOccurrence = absoluteOccurrence; }
/** Called once when this value is calculated, before onComplete */
- void setWeightedAbsoluteOccurrence(float weightedAbsoluteOccurrence) { this.weightedAbsoluteOccurrence=weightedAbsoluteOccurrence; }
+ void setWeightedAbsoluteOccurrence(float weightedAbsoluteOccurrence) { this.weightedAbsoluteOccurrence = weightedAbsoluteOccurrence; }
/** Called once when this value is calculated, before onComplete */
- void setSignificantOccurrence(float significantOccurrence) { this.significantOccurrence =significantOccurrence; }
+ void setSignificantOccurrence(float significantOccurrence) { this.significantOccurrence = significantOccurrence; }
/** Called once when matching is complete */
void onComplete() {
// segment distance - calculated from sorted segment starts
- if (segmentStarts.size()<=1) {
- segmentDistance=0;
+ if (segmentStarts.size() <= 1) {
+ segmentDistance = 0;
}
else {
Collections.sort(segmentStarts);
- for (int i=1; i<segmentStarts.size(); i++) {
- segmentDistance+=segmentStarts.get(i)-segmentStarts.get(i-1)+1;
+ for (int i = 1; i < segmentStarts.size(); i++) {
+ segmentDistance += segmentStarts.get(i) - segmentStarts.get(i - 1) + 1;
}
}
- if (head==-1) head=0;
- if (tail==-1) tail=0;
+ if (head == -1) head = 0;
+ if (tail == -1) tail = 0;
}
// Events on pairs ----------
/** Called when <i>any</i> pair is encountered */
void onPair(int i, int j, int previousJ) {
- int distance = j-previousJ-1;
+ int distance = j - previousJ - 1;
if (distance < 0) distance++; // Discontinuity where the two terms are in the same position
if (abs(distance) > source.getParameters().getProximityLimit()) return; // Contribution=0
@@ -463,7 +463,7 @@ public final class FieldMatchMetrics implements Cloneable {
unweightedProximity += pairProximity;
float connectedness = source.getQuery().getTerms()[i].getConnectedness();
- proximity += pow(pairProximity, connectedness/0.1) * max(0.1, connectedness);
+ proximity += (float)pow(pairProximity, connectedness / 0.1) * (float)max(0.1, connectedness);
pairs++;
}
@@ -498,8 +498,8 @@ public final class FieldMatchMetrics implements Cloneable {
@Override
public FieldMatchMetrics clone() {
try {
- FieldMatchMetrics clone=(FieldMatchMetrics)super.clone();
- clone.segmentStarts=new ArrayList<>(segmentStarts);
+ FieldMatchMetrics clone = (FieldMatchMetrics)super.clone();
+ clone.segmentStarts = new ArrayList<>(segmentStarts);
return clone;
}
catch (CloneNotSupportedException e) {
@@ -514,19 +514,19 @@ public final class FieldMatchMetrics implements Cloneable {
public String toStringDump() {
try {
- StringBuilder b=new StringBuilder();
+ StringBuilder b = new StringBuilder();
for (Method m : this.getClass().getDeclaredMethods()) {
if ( ! m.getName().startsWith("get")) continue;
- if (m.getReturnType()!=Integer.TYPE && m.getReturnType()!=Float.TYPE) continue;
- if ( m.getParameterTypes().length!=0 ) continue;
+ if (m.getReturnType() != Integer.TYPE && m.getReturnType() != Float.TYPE) continue;
+ if ( m.getParameterTypes().length != 0 ) continue;
- Object value=m.invoke(this,new Object[0]);
- b.append(m.getName().substring(3,4).toLowerCase() + m.getName().substring(4) + ": " + value + "\n");
+ Object value = m.invoke(this, new Object[0]);
+ b.append(m.getName().substring(3, 4).toLowerCase() + m.getName().substring(4) + ": " + value + "\n");
}
return b.toString();
}
catch (Exception e) {
- throw new RuntimeException("Programming error",e);
+ throw new RuntimeException("Programming error", e);
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
index 949e1f026f7..df721a4309e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
@@ -70,13 +70,13 @@ public final class GBDTNode extends ExpressionNode {
int offset = (int)nextValue - MAX_LEAF_VALUE;
boolean comparisonIsTrue = false;
if (offset < MAX_VARIABLES) {
- comparisonIsTrue = context.getDouble(offset)<values[pc++];
+ comparisonIsTrue = context.getDouble(offset) < values[pc++];
}
- else if (offset < MAX_VARIABLES*2) {
- comparisonIsTrue = context.getDouble(offset-MAX_VARIABLES)==values[pc++];
+ else if (offset < MAX_VARIABLES * 2) {
+ comparisonIsTrue = context.getDouble(offset - MAX_VARIABLES) == values[pc++];
}
- else if (offset<MAX_VARIABLES*3) {
- double testValue = context.getDouble(offset-MAX_VARIABLES*2);
+ else if (offset < MAX_VARIABLES * 3) {
+ double testValue = context.getDouble(offset - MAX_VARIABLES * 2);
int setValuesLeft = (int)values[pc++];
while (setValuesLeft > 0) { // test each value in the set
setValuesLeft--;
@@ -88,13 +88,13 @@ public final class GBDTNode extends ExpressionNode {
pc += setValuesLeft; // jump to after the set
}
else { // offset<MAX_VARIABLES*4
- comparisonIsTrue = ! (context.getDouble(offset-MAX_VARIABLES*3)>=values[pc++]);
+ comparisonIsTrue = ! (context.getDouble(offset - MAX_VARIABLES * 3) >= values[pc++]);
}
if (comparisonIsTrue)
pc++; // true branch - skip the jump value
else
- pc += values[pc]; // false branch - jump
+ pc += (int)values[pc]; // false branch - jump
}
else { // a leaf
return nextValue;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/gbdt/GbdtConverterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/gbdt/GbdtConverterTestCase.java
index d846d322720..8a33f320bb0 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/gbdt/GbdtConverterTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/gbdt/GbdtConverterTestCase.java
@@ -3,14 +3,12 @@ package com.yahoo.searchlib.gbdt;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import org.junit.After;
-import org.junit.Before;
import org.junit.Test;
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.io.UnsupportedEncodingException;
-import java.security.Permission;
+import java.nio.charset.StandardCharsets;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertEquals;
@@ -21,36 +19,6 @@ import static org.junit.Assert.fail;
*/
public class GbdtConverterTestCase {
- @Before
- @SuppressWarnings("removal")
- public void enableSecurityManager() {
- System.setSecurityManager(new NoExitSecurityManager());
- }
-
- @After
- @SuppressWarnings("removal")
- public void disableSecurityManager() {
- System.setSecurityManager(null);
- }
-
- @Test
- public void testOnlyOneArgumentIsAccepted() throws UnsupportedEncodingException {
- assertError("Usage: GbdtConverter <filename>\n", new String[0]);
- assertError("Usage: GbdtConverter <filename>\n", new String[] { "foo", "bar" });
- }
-
- @Test
- public void testFileIsFound() throws UnsupportedEncodingException {
- assertError("Could not find file 'not.found'.\n", new String[] { "not.found" });
- }
-
- @Test
- public void testFileParsingExceptionIsCaught() throws UnsupportedEncodingException {
- assertError("An error occurred while parsing the content of file 'src/test/files/gbdt_err.xml': " +
- "Node 'Unknown' has no 'DecisionTree' children.\n",
- new String[] { "src/test/files/gbdt_err.xml" });
- }
-
@Test
public void testEmptyTreesAreIgnored() throws Exception {
assertConvert("src/test/files/gbdt_empty_tree.xml",
@@ -125,7 +93,7 @@ public class GbdtConverterTestCase {
ByteArrayOutputStream out = new ByteArrayOutputStream();
System.setOut(new PrintStream(out));
GbdtConverter.main(new String[] { gbdtModelFile });
- String actualExpression = out.toString("UTF-8");
+ String actualExpression = out.toString(StandardCharsets.UTF_8);
assertEquals(expectedExpression, actualExpression);
assertNotNull(new RankingExpression(actualExpression));
}
@@ -138,26 +106,7 @@ public class GbdtConverterTestCase {
fail();
} catch (ExitException e) {
assertEquals(1, e.status);
- assertEquals(expected, err.toString("UTF-8"));
- }
- }
-
- @SuppressWarnings("removal")
- private static class NoExitSecurityManager extends SecurityManager {
-
- @Override
- public void checkPermission(Permission perm) {
- // allow anything
- }
-
- @Override
- public void checkPermission(Permission perm, Object context) {
- // allow anything
- }
-
- @Override
- public void checkExit(int status) {
- throw new ExitException(status);
+ assertEquals(expected, err.toString(StandardCharsets.UTF_8));
}
}
@@ -169,4 +118,5 @@ public class GbdtConverterTestCase {
this.status = status;
}
}
+
}
diff --git a/searchlib/src/tests/attribute/attribute_test.cpp b/searchlib/src/tests/attribute/attribute_test.cpp
index d2d3ccaad23..3a1a5b457ef 100644
--- a/searchlib/src/tests/attribute/attribute_test.cpp
+++ b/searchlib/src/tests/attribute/attribute_test.cpp
@@ -913,7 +913,7 @@ AttributeTest::testSingle()
AttributePtr ptr = createAttribute("sv-post-int32", cfg);
ptr->updateStat(true);
EXPECT_EQ(338972u, ptr->getStatus().getAllocated());
- EXPECT_EQ(101492u, ptr->getStatus().getUsed());
+ EXPECT_EQ(101632u, ptr->getStatus().getUsed());
addDocs(ptr, numDocs);
testSingle<IntegerAttribute, AttributeVector::largeint_t, int32_t>(ptr, values);
}
@@ -935,7 +935,7 @@ AttributeTest::testSingle()
AttributePtr ptr = createAttribute("sv-post-float", cfg);
ptr->updateStat(true);
EXPECT_EQ(338972u, ptr->getStatus().getAllocated());
- EXPECT_EQ(101492u, ptr->getStatus().getUsed());
+ EXPECT_EQ(101632u, ptr->getStatus().getUsed());
addDocs(ptr, numDocs);
testSingle<FloatingPointAttribute, double, float>(ptr, values);
}
@@ -948,7 +948,7 @@ AttributeTest::testSingle()
AttributePtr ptr = createAttribute("sv-string", Config(BasicType::STRING, CollectionType::SINGLE));
ptr->updateStat(true);
EXPECT_EQ(116528u + sizeof_large_string_entry, ptr->getStatus().getAllocated());
- EXPECT_EQ(52760u + sizeof_large_string_entry, ptr->getStatus().getUsed());
+ EXPECT_EQ(52844u + sizeof_large_string_entry, ptr->getStatus().getUsed());
addDocs(ptr, numDocs);
testSingle<StringAttribute, string, string>(ptr, values);
}
@@ -958,7 +958,7 @@ AttributeTest::testSingle()
AttributePtr ptr = createAttribute("sv-fs-string", cfg);
ptr->updateStat(true);
EXPECT_EQ(344848u + sizeof_large_string_entry, ptr->getStatus().getAllocated());
- EXPECT_EQ(104408u + sizeof_large_string_entry, ptr->getStatus().getUsed());
+ EXPECT_EQ(104556u + sizeof_large_string_entry, ptr->getStatus().getUsed());
addDocs(ptr, numDocs);
testSingle<StringAttribute, string, string>(ptr, values);
}
@@ -1110,7 +1110,7 @@ AttributeTest::testArray()
AttributePtr ptr = createAttribute("a-fs-int32", cfg);
ptr->updateStat(true);
EXPECT_EQ(844116u, ptr->getStatus().getAllocated());
- EXPECT_EQ(581232u, ptr->getStatus().getUsed());
+ EXPECT_EQ(581372u, ptr->getStatus().getUsed());
addDocs(ptr, numDocs);
testArray<IntegerAttribute, AttributeVector::largeint_t>(ptr, values);
}
@@ -1129,7 +1129,7 @@ AttributeTest::testArray()
AttributePtr ptr = createAttribute("a-fs-float", cfg);
ptr->updateStat(true);
EXPECT_EQ(844116u, ptr->getStatus().getAllocated());
- EXPECT_EQ(581232u, ptr->getStatus().getUsed());
+ EXPECT_EQ(581372u, ptr->getStatus().getUsed());
addDocs(ptr, numDocs);
testArray<FloatingPointAttribute, double>(ptr, values);
}
@@ -1141,7 +1141,7 @@ AttributeTest::testArray()
AttributePtr ptr = createAttribute("a-string", Config(BasicType::STRING, CollectionType::ARRAY));
ptr->updateStat(true);
EXPECT_EQ(599784u + sizeof_large_string_entry, ptr->getStatus().getAllocated());
- EXPECT_EQ(532480u + sizeof_large_string_entry, ptr->getStatus().getUsed());
+ EXPECT_EQ(532564u + sizeof_large_string_entry, ptr->getStatus().getUsed());
addDocs(ptr, numDocs);
testArray<StringAttribute, string>(ptr, values);
}
@@ -1151,7 +1151,7 @@ AttributeTest::testArray()
AttributePtr ptr = createAttribute("afs-string", cfg);
ptr->updateStat(true);
EXPECT_EQ(849992u + sizeof_large_string_entry, ptr->getStatus().getAllocated());
- EXPECT_EQ(584148u + sizeof_large_string_entry, ptr->getStatus().getUsed());
+ EXPECT_EQ(584296u + sizeof_large_string_entry, ptr->getStatus().getUsed());
addDocs(ptr, numDocs);
testArray<StringAttribute, string>(ptr, values);
}
@@ -1718,7 +1718,7 @@ AttributeTest::testStatus()
ptr->commit(true);
EXPECT_EQ(ptr->getStatus().getNumDocs(), 100u);
EXPECT_EQ(ptr->getStatus().getNumValues(), 100u);
- EXPECT_EQ(ptr->getStatus().getNumUniqueValues(), 1u);
+ EXPECT_EQ(ptr->getStatus().getNumUniqueValues(), 2u);
size_t expUsed = 0;
expUsed += 1 * InternalNodeSize + 1 * LeafNodeSize; // enum store tree
expUsed += 1 * 32; // enum store (uniquevalues * bytes per entry)
@@ -1741,7 +1741,7 @@ AttributeTest::testStatus()
ptr->commit(true);
EXPECT_EQ(ptr->getStatus().getNumDocs(), numDocs);
EXPECT_EQ(ptr->getStatus().getNumValues(), numDocs*numValuesPerDoc);
- EXPECT_EQ(ptr->getStatus().getNumUniqueValues(), numUniq);
+ EXPECT_EQ(ptr->getStatus().getNumUniqueValues(), numUniq + 1);
size_t expUsed = 0;
expUsed += 1 * InternalNodeSize + 1 * LeafNodeSize; // Approximate enum store tree
expUsed += 272; // TODO Approximate... enum store (16 unique values, 17 bytes per entry)
@@ -2145,12 +2145,12 @@ AttributeTest::test_default_value_ref_count_is_updated_after_shrink_lid_space()
const auto & iattr = dynamic_cast<const search::IntegerAttributeTemplate<int32_t> &>(*attr);
attr->addReservedDoc();
attr->addDocs(10);
- EXPECT_EQ(11u, get_default_value_ref_count(*attr, iattr.defaultValue()));
+ EXPECT_EQ(12u, get_default_value_ref_count(*attr, iattr.defaultValue()));
attr->compactLidSpace(6);
- EXPECT_EQ(11u, get_default_value_ref_count(*attr, iattr.defaultValue()));
+ EXPECT_EQ(12u, get_default_value_ref_count(*attr, iattr.defaultValue()));
attr->shrinkLidSpace();
EXPECT_EQ(6u, attr->getNumDocs());
- EXPECT_EQ(6u, get_default_value_ref_count(*attr, iattr.defaultValue()));
+ EXPECT_EQ(7u, get_default_value_ref_count(*attr, iattr.defaultValue()));
}
template <typename AttributeType>
@@ -2170,7 +2170,7 @@ AttributeTest::requireThatAddressSpaceUsageIsReported(const Config &config, bool
AddressSpaceUsage after = attrPtr->getAddressSpaceUsage();
if (attrPtr->hasEnum()) {
LOG(info, "requireThatAddressSpaceUsageIsReported(%s): Has enum", attrName.c_str());
- EXPECT_EQ(before.enum_store_usage().used(), 1u);
+ EXPECT_EQ(before.enum_store_usage().used(), 2u);
EXPECT_EQ(before.enum_store_usage().dead(), 1u);
EXPECT_GT(after.enum_store_usage().used(), before.enum_store_usage().used());
EXPECT_GE(after.enum_store_usage().limit(), before.enum_store_usage().limit());
diff --git a/searchlib/src/tests/attribute/enumeratedsave/enumeratedsave_test.cpp b/searchlib/src/tests/attribute/enumeratedsave/enumeratedsave_test.cpp
index 820f39089d1..5501c99652b 100644
--- a/searchlib/src/tests/attribute/enumeratedsave/enumeratedsave_test.cpp
+++ b/searchlib/src/tests/attribute/enumeratedsave/enumeratedsave_test.cpp
@@ -183,7 +183,7 @@ MemAttr::bufEqual(const Buffer &lhs, const Buffer &rhs) const
return false;
if (lhs.get() == NULL)
return true;
- if (!EXPECT_TRUE(lhs->getDataLen() == rhs->getDataLen()))
+ if (!EXPECT_EQUAL(lhs->getDataLen(), rhs->getDataLen()))
return false;
if (!EXPECT_TRUE(vespalib::memcmp_safe(lhs->getData(), rhs->getData(),
lhs->getDataLen()) == 0))
@@ -243,8 +243,9 @@ EnumeratedSaveTest::populate(IntegerAttribute &v, unsigned seed,
int weight = 1;
for(size_t i(0), m(v.getNumDocs()); i < m; i++) {
v.clearDoc(i);
- if (i == 9)
+ if (i == 9) {
continue;
+ }
if (i == 7) {
if (v.hasMultiValue()) {
v.append(i, -42, 27);
@@ -270,7 +271,7 @@ EnumeratedSaveTest::populate(IntegerAttribute &v, unsigned seed,
i + 1);
}
} else {
- EXPECT_TRUE( v.update(i, lrand48() & mask) );
+ EXPECT_TRUE( v.update(i, rnd.lrand48() & mask) );
}
}
v.commit();
@@ -288,8 +289,9 @@ EnumeratedSaveTest::populate(FloatingPointAttribute &v, unsigned seed,
int weight = 1;
for(size_t i(0), m(v.getNumDocs()); i < m; i++) {
v.clearDoc(i);
- if (i == 9)
+ if (i == 9) {
continue;
+ }
if (i == 7) {
if (v.hasMultiValue()) {
v.append(i, -42.0, 27);
@@ -315,7 +317,7 @@ EnumeratedSaveTest::populate(FloatingPointAttribute &v, unsigned seed,
i + 1);
}
} else {
- EXPECT_TRUE( v.update(i, lrand48()) );
+ EXPECT_TRUE( v.update(i, rnd.lrand48()) );
}
}
v.commit();
@@ -332,8 +334,9 @@ EnumeratedSaveTest::populate(StringAttribute &v, unsigned seed,
int weight = 1;
for(size_t i(0), m(v.getNumDocs()); i < m; i++) {
v.clearDoc(i);
- if (i == 9)
+ if (i == 9) {
continue;
+ }
if (i == 7) {
if (v.hasMultiValue()) {
v.append(i, "foo", 27);
@@ -712,9 +715,9 @@ EnumeratedSaveTest::test(BasicType bt, CollectionType ct,
Config check_cfg(cfg);
check_cfg.setFastSearch(true);
- checkLoad<VectorType, BufferType>(check_cfg, pref + "0_ee", v0);
- checkLoad<VectorType, BufferType>(check_cfg, pref + "1_ee", v1);
- checkLoad<VectorType, BufferType>(check_cfg, pref + "2_ee", v2);
+ TEST_DO((checkLoad<VectorType, BufferType>(check_cfg, pref + "0_ee", v0)));
+ TEST_DO((checkLoad<VectorType, BufferType>(check_cfg, pref + "1_ee", v1)));
+ TEST_DO((checkLoad<VectorType, BufferType>(check_cfg, pref + "2_ee", v2)));
TEST_DO((testReload<VectorType, BufferType>(v0, v1, v2,
mv0, mv1, mv2,
diff --git a/searchlib/src/tests/attribute/enumstore/enumstore_test.cpp b/searchlib/src/tests/attribute/enumstore/enumstore_test.cpp
index b3c7516777c..2b01c266e80 100644
--- a/searchlib/src/tests/attribute/enumstore/enumstore_test.cpp
+++ b/searchlib/src/tests/attribute/enumstore/enumstore_test.cpp
@@ -180,15 +180,35 @@ TYPED_TEST(FloatEnumStoreTest, numbers_can_be_inserted_and_retrieved)
}
}
+TEST(EnumStoreTest, default_value_is_present)
+{
+ StringEnumStore ses(false, DictionaryConfig::Type::BTREE);
+ using EntryType = StringEnumStore::EntryType;
+ EntryType undefined = attribute::getUndefined<EntryType>();
+ EnumIndex idx;
+ EXPECT_TRUE(ses.find_index(undefined, idx));
+ EXPECT_TRUE(idx.valid());
+ EXPECT_EQ(ses.get_default_value_ref().load_relaxed(), idx);
+ ses.clear_default_value_ref();
+ EXPECT_FALSE(ses.find_index(undefined, idx));
+ EXPECT_FALSE(ses.get_default_value_ref().load_relaxed().valid());
+ ses.setup_default_value_ref();
+ idx = EnumIndex();
+ EXPECT_TRUE(ses.find_index(undefined, idx));
+ EXPECT_TRUE(idx.valid());
+ EXPECT_EQ(ses.get_default_value_ref().load_relaxed(), idx);
+}
+
TEST(EnumStoreTest, test_find_folded_on_string_enum_store)
{
StringEnumStore ses(false, DictionaryConfig::Type::BTREE);
+ using EntryType = StringEnumStore::EntryType;
std::vector<EnumIndex> indices;
std::vector<std::string> unique({"", "one", "two", "TWO", "Two", "three"});
for (std::string &str : unique) {
EnumIndex idx = ses.insert(str.c_str());
indices.push_back(idx);
- EXPECT_EQ(1u, ses.get_ref_count(idx));
+ EXPECT_EQ((str == attribute::getUndefined<EntryType>()) ? 2u : 1u, ses.get_ref_count(idx));
}
ses.freeze_dictionary();
for (uint32_t i = 0; i < indices.size(); ++i) {
@@ -233,13 +253,14 @@ void
StringEnumStoreTest::testInsert(bool hasPostings)
{
StringEnumStore ses(hasPostings, DictionaryConfig::Type::BTREE);
+ using EntryType = StringEnumStore::EntryType;
std::vector<EnumIndex> indices;
std::vector<std::string> unique = {"", "add", "enumstore", "unique"};
for (const auto & i : unique) {
EnumIndex idx = ses.insert(i.c_str());
- EXPECT_EQ(1u, ses.get_ref_count(idx));
+ EXPECT_EQ((i == attribute::getUndefined<EntryType>()) ? 2u : 1u, ses.get_ref_count(idx));
indices.push_back(idx);
EXPECT_TRUE(ses.find_index(i.c_str(), idx));
}
@@ -253,7 +274,7 @@ StringEnumStoreTest::testInsert(bool hasPostings)
EnumIndex idx;
EXPECT_TRUE(ses.find_index(unique[i].c_str(), idx));
EXPECT_TRUE(idx == indices[i]);
- EXPECT_EQ(1u, ses.get_ref_count(indices[i]));
+ EXPECT_EQ((i == 0) ? 2u : 1u, ses.get_ref_count(indices[i]));
const char* value = nullptr;
EXPECT_TRUE(ses.get_value(indices[i], value));
EXPECT_TRUE(strcmp(unique[i].c_str(), value) == 0);
@@ -354,22 +375,22 @@ TEST(EnumStoreTest, address_space_usage_is_reported)
NumericEnumStore store(false, DictionaryConfig::Type::BTREE);
using vespalib::AddressSpace;
- EXPECT_EQ(AddressSpace(1, 1, ADDRESS_LIMIT), store.get_values_address_space_usage());
- EnumIndex idx1 = store.insert(10);
EXPECT_EQ(AddressSpace(2, 1, ADDRESS_LIMIT), store.get_values_address_space_usage());
- EnumIndex idx2 = store.insert(20);
+ EnumIndex idx1 = store.insert(10);
// Address limit increases because buffer is re-sized.
EXPECT_EQ(AddressSpace(3, 1, ADDRESS_LIMIT + 2), store.get_values_address_space_usage());
+ EnumIndex idx2 = store.insert(20);
+ EXPECT_EQ(AddressSpace(4, 1, ADDRESS_LIMIT + 2), store.get_values_address_space_usage());
dec_ref_count(store, idx1);
- EXPECT_EQ(AddressSpace(3, 2, ADDRESS_LIMIT + 2), store.get_values_address_space_usage());
+ EXPECT_EQ(AddressSpace(4, 2, ADDRESS_LIMIT + 2), store.get_values_address_space_usage());
dec_ref_count(store, idx2);
- EXPECT_EQ(AddressSpace(3, 3, ADDRESS_LIMIT + 2), store.get_values_address_space_usage());
+ EXPECT_EQ(AddressSpace(4, 3, ADDRESS_LIMIT + 2), store.get_values_address_space_usage());
}
TEST(EnumStoreTest, provided_memory_allocator_is_used)
{
AllocStats stats;
- NumericEnumStore ses(false, DictionaryConfig::Type::BTREE, std::make_unique<MemoryAllocatorObserver>(stats));
+ NumericEnumStore ses(false, DictionaryConfig::Type::BTREE, std::make_unique<MemoryAllocatorObserver>(stats), attribute::getUndefined<NumericEnumStore::EntryType>());
EXPECT_EQ(AllocStats(1, 0), stats);
}
@@ -539,6 +560,7 @@ TYPED_TEST_SUITE(LoaderTest, LoaderTestTypes);
TYPED_TEST(LoaderTest, store_is_instantiated_with_enumerated_loader)
{
+ this->store.clear_default_value_ref();
auto loader = this->store.make_enumerated_loader();
this->load_values(loader);
loader.allocate_enums_histogram();
@@ -554,6 +576,7 @@ TYPED_TEST(LoaderTest, store_is_instantiated_with_enumerated_loader)
TYPED_TEST(LoaderTest, store_is_instantiated_with_enumerated_postings_loader)
{
+ this->store.clear_default_value_ref();
auto loader = this->store.make_enumerated_postings_loader();
this->load_values(loader);
this->set_ref_count(0, 1, loader);
@@ -568,6 +591,7 @@ TYPED_TEST(LoaderTest, store_is_instantiated_with_enumerated_postings_loader)
TYPED_TEST(LoaderTest, store_is_instantiated_with_non_enumerated_loader)
{
+ this->store.clear_default_value_ref();
auto loader = this->store.make_non_enumerated_loader();
using MyValues = LoaderTestValues<typename TypeParam::EnumStoreType>;
loader.insert(MyValues::values[0], 100);
@@ -610,6 +634,7 @@ public:
void test_normalize_posting_lists(bool use_filter, bool one_filter);
void test_foreach_posting_list(bool one_filter);
static EntryRef fake_pidx() { return EntryRef(42); }
+ EnumIndex check_default_value_ref() const noexcept;
};
template <typename EnumStoreTypeAndDictionaryType>
@@ -775,6 +800,16 @@ EnumStoreDictionaryTest<EnumStoreTypeAndDictionaryType>::test_foreach_posting_li
clear_sample_values(large_population);
}
+template <typename EnumStoreTypeAndDictionaryType>
+EnumIndex
+EnumStoreDictionaryTest<EnumStoreTypeAndDictionaryType>::check_default_value_ref() const noexcept
+{
+ EnumIndex default_value_ref = store.get_default_value_ref().load_relaxed();
+ EXPECT_TRUE(default_value_ref.valid());
+ EXPECT_EQ(attribute::getUndefined<EntryType>(), store.get_value(default_value_ref));
+ return default_value_ref;
+}
+
using EnumStoreDictionaryTestTypes = ::testing::Types<BTreeNumericEnumStore, HybridNumericEnumStore, HashNumericEnumStore>;
TYPED_TEST_SUITE(EnumStoreDictionaryTest, EnumStoreDictionaryTestTypes);
@@ -875,6 +910,7 @@ TYPED_TEST(EnumStoreDictionaryTest, compact_worst_works)
updater.commit();
generation_t gen = 3;
inc_generation(gen, this->store);
+ // Compact dictionary
auto& dict = this->store.get_dictionary();
if (dict.get_has_btree_dictionary()) {
EXPECT_LT(CompactionStrategy::DEAD_BYTES_SLACK, dict.get_btree_memory_usage().deadBytes());
@@ -902,8 +938,31 @@ TYPED_TEST(EnumStoreDictionaryTest, compact_worst_works)
if (dict.get_has_hash_dictionary()) {
EXPECT_GT(CompactionStrategy::DEAD_BYTES_SLACK, dict.get_hash_memory_usage().deadBytes());
}
+ auto old_default_value_ref = this->check_default_value_ref();
+ // Compact values
+ EXPECT_LT(CompactionStrategy::DEAD_BYTES_SLACK, this->store.get_values_memory_usage().deadBytes());
+ compaction_strategy = CompactionStrategy::make_compact_all_active_buffers_strategy();
+ int compact_values_count = 0;
+ for (uint32_t i = 0; i < 2; ++i) {
+ this->store.update_stat(compaction_strategy);
+ auto remapper = this->store.consider_compact_values(compaction_strategy);
+ if (remapper) {
+ remapper->done();
+ ++compact_values_count;
+ } else {
+ break;
+ }
+ EXPECT_FALSE(this->store.consider_compact_values(compaction_strategy));
+ inc_generation(gen, this->store);
+ }
+ EXPECT_EQ(1, compact_values_count);
+ auto new_default_value_ref = this->check_default_value_ref();
+ EXPECT_NE(old_default_value_ref, new_default_value_ref);
+ EXPECT_GT(CompactionStrategy::DEAD_BYTES_SLACK, this->store.get_values_memory_usage().deadBytes());
+
std::vector<int32_t> exp_values;
std::vector<int32_t> values;
+ exp_values.push_back(std::numeric_limits<int32_t>::min());
for (int32_t i = 0; i < 20; ++i) {
exp_values.push_back(i);
}
diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp
index f354f635def..2c202d9131b 100644
--- a/searchlib/src/tests/query/streaming_query_test.cpp
+++ b/searchlib/src/tests/query/streaming_query_test.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/searchlib/query/streaming/query.h>
+#include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h>
#include <vespa/searchlib/query/tree/querybuilder.h>
#include <vespa/searchlib/query/tree/simplequery.h>
#include <vespa/searchlib/query/tree/stackdumpcreator.h>
@@ -804,6 +805,42 @@ TEST("testSameElementEvaluate") {
EXPECT_TRUE(sameElem->evaluate());
}
+TEST("test_nearest_neighbor_query_node")
+{
+ QueryBuilder<SimpleQueryNodeTypes> builder;
+ constexpr double distance_threshold = 35.5;
+ constexpr int32_t id = 42;
+ constexpr int32_t weight = 1;
+ constexpr uint32_t target_num_hits = 100;
+ constexpr bool allow_approximate = false;
+ constexpr uint32_t explore_additional_hits = 800;
+ constexpr double raw_score = 0.5;
+ builder.add_nearest_neighbor_term("qtensor", "field", id, Weight(weight), target_num_hits, allow_approximate, explore_additional_hits, distance_threshold);
+ auto build_node = builder.build();
+ auto stack_dump = StackDumpCreator::create(*build_node);
+ QueryNodeResultFactory empty;
+ Query q(empty, stack_dump);
+ auto* qterm = dynamic_cast<QueryTerm *>(&q.getRoot());
+ EXPECT_TRUE(qterm != nullptr);
+ auto* node = dynamic_cast<NearestNeighborQueryNode *>(&q.getRoot());
+ EXPECT_TRUE(node != nullptr);
+ EXPECT_EQUAL(node, qterm->as_nearest_neighbor_query_node());
+ EXPECT_EQUAL("qtensor", node->get_query_tensor_name());
+ EXPECT_EQUAL("field", node->getIndex());
+ EXPECT_EQUAL(id, static_cast<int32_t>(node->uniqueId()));
+ EXPECT_EQUAL(weight, node->weight().percent());
+ EXPECT_EQUAL(distance_threshold, node->get_distance_threshold());
+ EXPECT_FALSE(node->get_raw_score().has_value());
+ EXPECT_FALSE(node->evaluate());
+ node->set_raw_score(raw_score);
+ EXPECT_TRUE(node->get_raw_score().has_value());
+ EXPECT_EQUAL(raw_score, node->get_raw_score().value());
+ EXPECT_TRUE(node->evaluate());
+ node->reset();
+ EXPECT_FALSE(node->get_raw_score().has_value());
+ EXPECT_FALSE(node->evaluate());
+}
+
TEST("Control the size of query terms") {
EXPECT_EQUAL(112u, sizeof(QueryTermSimple));
EXPECT_EQUAL(128u, sizeof(QueryTermUCS4));
diff --git a/searchlib/src/vespa/searchcommon/common/undefinedvalues.h b/searchlib/src/vespa/searchcommon/common/undefinedvalues.h
index bbe3198a8dc..a080648c054 100644
--- a/searchlib/src/vespa/searchcommon/common/undefinedvalues.h
+++ b/searchlib/src/vespa/searchcommon/common/undefinedvalues.h
@@ -24,6 +24,10 @@ inline constexpr double getUndefined<double>() {
return -std::numeric_limits<double>::quiet_NaN();
}
+template <>
+inline constexpr const char* getUndefined<const char*>() {
+ return "";
+}
// for all signed integers
template <typename T>
diff --git a/searchlib/src/vespa/searchlib/attribute/attributevector.cpp b/searchlib/src/vespa/searchlib/attribute/attributevector.cpp
index 1a2c8c43b94..f4ab447ed51 100644
--- a/searchlib/src/vespa/searchlib/attribute/attributevector.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/attributevector.cpp
@@ -353,6 +353,8 @@ AttributeVector::load(vespalib::Executor * executor) {
bool loaded = onLoad(executor);
if (loaded) {
commit();
+ incGeneration();
+ updateStat(true);
}
_loaded = loaded;
return _loaded;
@@ -451,19 +453,6 @@ AttributeVector::set_reserved_doc_values()
return;
}
clearDoc(docId);
- if (hasMultiValue()) {
- if (isFloatingPointType()) {
- auto * vec = dynamic_cast<FloatingPointAttribute *>(this);
- bool appendedUndefined = vec->append(0, attribute::getUndefined<double>(), 1);
- assert(appendedUndefined);
- (void) appendedUndefined;
- } else if (isStringType()) {
- auto * vec = dynamic_cast<StringAttribute *>(this);
- bool appendedUndefined = vec->append(0, StringAttribute::defaultValue(), 1);
- assert(appendedUndefined);
- (void) appendedUndefined;
- }
- }
commit();
}
diff --git a/searchlib/src/vespa/searchlib/attribute/enum_store_loaders.cpp b/searchlib/src/vespa/searchlib/attribute/enum_store_loaders.cpp
index eeaa3e9539f..c1345b4f770 100644
--- a/searchlib/src/vespa/searchlib/attribute/enum_store_loaders.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/enum_store_loaders.cpp
@@ -93,6 +93,7 @@ EnumeratedLoader::build_dictionary()
{
_store.get_dictionary().build(_indexes);
release_enum_indexes();
+ _store.setup_default_value_ref();
}
EnumeratedPostingsLoader::EnumeratedPostingsLoader(IEnumStore& store)
@@ -131,6 +132,13 @@ EnumeratedPostingsLoader::build_dictionary()
_store.get_dictionary().build_with_payload(_indexes, _posting_indexes);
release_enum_indexes();
EntryRefVector().swap(_posting_indexes);
+ _store.setup_default_value_ref();
+}
+
+void
+EnumeratedPostingsLoader::build_empty_dictionary()
+{
+ _store.setup_default_value_ref();
}
}
diff --git a/searchlib/src/vespa/searchlib/attribute/enum_store_loaders.h b/searchlib/src/vespa/searchlib/attribute/enum_store_loaders.h
index 2a72fcac628..937ceb91628 100644
--- a/searchlib/src/vespa/searchlib/attribute/enum_store_loaders.h
+++ b/searchlib/src/vespa/searchlib/attribute/enum_store_loaders.h
@@ -85,6 +85,7 @@ public:
void set_ref_count(Index idx, uint32_t ref_count);
vespalib::ArrayRef<EntryRef> initialize_empty_posting_indexes();
void build_dictionary();
+ void build_empty_dictionary();
};
}
diff --git a/searchlib/src/vespa/searchlib/attribute/enumattribute.h b/searchlib/src/vespa/searchlib/attribute/enumattribute.h
index f0ff23a06b4..4753dbe65f9 100644
--- a/searchlib/src/vespa/searchlib/attribute/enumattribute.h
+++ b/searchlib/src/vespa/searchlib/attribute/enumattribute.h
@@ -50,13 +50,12 @@ protected:
/*
* Iterate through the change vector and find new unique values.
- * Perform compaction if necessary and insert the new unique values into the EnumStore.
+ * Insert the new unique values into the EnumStore.
*/
void insertNewUniqueValues(EnumStoreBatchUpdater& updater);
virtual void considerAttributeChange(const Change & c, EnumStoreBatchUpdater & inserter) = 0;
vespalib::MemoryUsage getEnumStoreValuesMemoryUsage() const override;
void populate_address_space_usage(AddressSpaceUsage& usage) const override;
- void cache_change_data_entry_ref(const Change& c) const;
public:
EnumAttribute(const vespalib::string & baseFileName, const AttributeVector::Config & cfg);
~EnumAttribute();
diff --git a/searchlib/src/vespa/searchlib/attribute/enumattribute.hpp b/searchlib/src/vespa/searchlib/attribute/enumattribute.hpp
index c5188b89129..66d555df3cb 100644
--- a/searchlib/src/vespa/searchlib/attribute/enumattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/enumattribute.hpp
@@ -15,7 +15,7 @@ EnumAttribute<B>::
EnumAttribute(const vespalib::string &baseFileName,
const AttributeVector::Config &cfg)
: B(baseFileName, cfg),
- _enumStore(cfg.fastSearch(), cfg.get_dictionary_config(), this->get_memory_allocator())
+ _enumStore(cfg.fastSearch(), cfg.get_dictionary_config(), this->get_memory_allocator(), this->_defaultValue._data.raw())
{
this->setEnum(true);
}
@@ -50,6 +50,7 @@ void EnumAttribute<B>::load_enum_store(LoadedVector& loaded)
loader.set_ref_count_for_last_value(prevRefCount);
}
loader.build_dictionary();
+ _enumStore.setup_default_value_ref();
}
}
@@ -85,15 +86,4 @@ EnumAttribute<B>::populate_address_space_usage(AddressSpaceUsage& usage) const
usage.set(AddressSpaceComponents::enum_store, _enumStore.get_values_address_space_usage());
}
-template <typename B>
-void
-EnumAttribute<B>::cache_change_data_entry_ref(const Change& c) const
-{
- EnumIndex new_idx;
- _enumStore.find_index(c._data.raw(), new_idx);
- c.set_entry_ref(new_idx.ref());
-}
-
} // namespace search
-
-
diff --git a/searchlib/src/vespa/searchlib/attribute/enumstore.h b/searchlib/src/vespa/searchlib/attribute/enumstore.h
index 266437fafa1..f6467194d74 100644
--- a/searchlib/src/vespa/searchlib/attribute/enumstore.h
+++ b/searchlib/src/vespa/searchlib/attribute/enumstore.h
@@ -28,6 +28,9 @@ namespace search {
* It uses an instance of vespalib::datastore::UniqueStore to store the actual values.
* It also exposes the dictionary used for fast lookups into the set of unique values.
*
+ * The default value is always present except for a short time window
+ * during attribute vector load.
+ *
* @tparam EntryType The type of the entries/values stored.
* It has special handling of type 'const char *' for strings.
*/
@@ -55,6 +58,8 @@ private:
ComparatorType _comparator;
ComparatorType _foldedComparator;
enumstore::EnumStoreCompactionSpec _compaction_spec;
+ EntryType _default_value;
+ AtomicIndex _default_value_ref;
EnumStoreT(const EnumStoreT & rhs) = delete;
EnumStoreT & operator=(const EnumStoreT & rhs) = delete;
@@ -75,7 +80,7 @@ private:
std::unique_ptr<EntryComparator> allocate_optionally_folded_comparator(bool folded) const;
ComparatorType make_optionally_folded_comparator(bool folded) const;
public:
- EnumStoreT(bool has_postings, const search::DictionaryConfig& dict_cfg, std::shared_ptr<vespalib::alloc::MemoryAllocator> memory_allocator);
+ EnumStoreT(bool has_postings, const search::DictionaryConfig& dict_cfg, std::shared_ptr<vespalib::alloc::MemoryAllocator> memory_allocator, EntryType default_value);
EnumStoreT(bool has_postings, const search::DictionaryConfig & dict_cfg);
~EnumStoreT() override;
@@ -201,6 +206,9 @@ public:
bool find_index(EntryType value, Index& idx) const;
void free_unused_values() override;
void free_unused_values(IndexList to_remove);
+ void clear_default_value_ref() override;
+ void setup_default_value_ref() override;
+ const AtomicIndex& get_default_value_ref() const noexcept { return _default_value_ref; }
vespalib::MemoryUsage update_stat(const CompactionStrategy& compaction_strategy) override;
std::unique_ptr<EnumIndexRemapper> consider_compact_values(const CompactionStrategy& compaction_strategy) override;
std::unique_ptr<EnumIndexRemapper> compact_worst_values(CompactionSpec compaction_spec, const CompactionStrategy& compaction_strategy) override;
diff --git a/searchlib/src/vespa/searchlib/attribute/enumstore.hpp b/searchlib/src/vespa/searchlib/attribute/enumstore.hpp
index bc767a296eb..c0eebee8e94 100644
--- a/searchlib/src/vespa/searchlib/attribute/enumstore.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/enumstore.hpp
@@ -17,6 +17,7 @@
#include <vespa/vespalib/datastore/unique_store.hpp>
#include <vespa/vespalib/datastore/unique_store_string_allocator.hpp>
#include <vespa/vespalib/util/array.hpp>
+#include <vespa/searchcommon/common/undefinedvalues.h>
#include <vespa/searchlib/util/bufferwriter.h>
#include <vespa/vespalib/datastore/compaction_strategy.h>
@@ -72,23 +73,26 @@ EnumStoreT<EntryT>::load_unique_value(const void* src, size_t available, Index&
}
template <typename EntryT>
-EnumStoreT<EntryT>::EnumStoreT(bool has_postings, const DictionaryConfig& dict_cfg, std::shared_ptr<vespalib::alloc::MemoryAllocator> memory_allocator)
+EnumStoreT<EntryT>::EnumStoreT(bool has_postings, const DictionaryConfig& dict_cfg, std::shared_ptr<vespalib::alloc::MemoryAllocator> memory_allocator, EntryType default_value)
: _store(std::move(memory_allocator)),
_dict(),
_is_folded(dict_cfg.getMatch() == DictionaryConfig::Match::UNCASED),
_comparator(_store.get_data_store()),
_foldedComparator(make_optionally_folded_comparator(is_folded())),
- _compaction_spec()
+ _compaction_spec(),
+ _default_value(default_value),
+ _default_value_ref()
{
_store.set_dictionary(make_enum_store_dictionary(*this, has_postings, dict_cfg,
allocate_comparator(),
allocate_optionally_folded_comparator(is_folded())));
_dict = static_cast<IEnumStoreDictionary*>(&_store.get_dictionary());
+ setup_default_value_ref();
}
template <typename EntryT>
EnumStoreT<EntryT>::EnumStoreT(bool has_postings, const DictionaryConfig& dict_cfg)
- : EnumStoreT<EntryT>(has_postings, dict_cfg, {})
+ : EnumStoreT<EntryT>(has_postings, dict_cfg, {}, attribute::getUndefined<EntryType>())
{
}
@@ -215,6 +219,33 @@ EnumStoreT<EntryT>::insert(EntryType value)
return _store.add(value).ref();
}
+
+template <typename EntryT>
+void
+EnumStoreT<EntryT>::clear_default_value_ref()
+{
+ auto ref = _default_value_ref.load_relaxed();
+ if (ref.valid()) {
+ auto updater = make_batch_updater();
+ updater.dec_ref_count(ref);
+ _default_value_ref.store_relaxed(Index());
+ updater.commit();
+ }
+}
+
+template <typename EntryT>
+void
+EnumStoreT<EntryT>::setup_default_value_ref()
+{
+ if (!_default_value_ref.load_relaxed().valid()) {
+ auto updater = make_batch_updater();
+ auto ref = updater.insert(_default_value);
+ updater.inc_ref_count(ref);
+ _default_value_ref.store_relaxed(ref);
+ updater.commit();
+ }
+}
+
template <typename EntryT>
vespalib::MemoryUsage
EnumStoreT<EntryT>::update_stat(const CompactionStrategy& compaction_strategy)
@@ -236,7 +267,14 @@ template <typename EntryT>
std::unique_ptr<IEnumStore::EnumIndexRemapper>
EnumStoreT<EntryT>::compact_worst_values(CompactionSpec compaction_spec, const CompactionStrategy& compaction_strategy)
{
- return _store.compact_worst(compaction_spec, compaction_strategy);
+ auto remapper = _store.compact_worst(compaction_spec, compaction_strategy);
+ if (remapper) {
+ auto ref = _default_value_ref.load_relaxed();
+ if (ref.valid() && remapper->get_entry_ref_filter().has(ref)) {
+ _default_value_ref.store_release(remapper->remap(ref));
+ }
+ }
+ return remapper;
}
template <typename EntryT>
diff --git a/searchlib/src/vespa/searchlib/attribute/i_enum_store.h b/searchlib/src/vespa/searchlib/attribute/i_enum_store.h
index 2157db3e5ed..aa9fd549b60 100644
--- a/searchlib/src/vespa/searchlib/attribute/i_enum_store.h
+++ b/searchlib/src/vespa/searchlib/attribute/i_enum_store.h
@@ -74,6 +74,8 @@ public:
virtual std::unique_ptr<Enumerator> make_enumerator() = 0;
virtual std::unique_ptr<vespalib::datastore::EntryComparator> allocate_comparator() const = 0;
+ virtual void clear_default_value_ref() = 0;
+ virtual void setup_default_value_ref() = 0;
};
}
diff --git a/searchlib/src/vespa/searchlib/attribute/multinumericenumattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multinumericenumattribute.hpp
index edfea23f48d..59c1216829d 100644
--- a/searchlib/src/vespa/searchlib/attribute/multinumericenumattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/multinumericenumattribute.hpp
@@ -97,6 +97,10 @@ MultiValueNumericEnumAttribute<B, M>::onLoad(vespalib::Executor *)
return false;
}
+ this->_enumStore.clear_default_value_ref();
+ this->commit();
+ this->incGeneration();
+
this->setCreateSerialNum(attrReader.getCreateSerialNum());
if (attrReader.getEnumerated()) {
diff --git a/searchlib/src/vespa/searchlib/attribute/multistringattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multistringattribute.hpp
index a63862126fa..7b11fcd59f4 100644
--- a/searchlib/src/vespa/searchlib/attribute/multistringattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/multistringattribute.hpp
@@ -42,7 +42,6 @@ MultiValueStringAttributeT<B, M>::freezeEnumDictionary()
this->getEnumStore().freeze_dictionary();
}
-
template <typename B, typename M>
std::unique_ptr<attribute::SearchContext>
MultiValueStringAttributeT<B, M>::getSearch(QueryTermSimpleUP qTerm,
diff --git a/searchlib/src/vespa/searchlib/attribute/postinglistattribute.cpp b/searchlib/src/vespa/searchlib/attribute/postinglistattribute.cpp
index 6ef3b575c3e..01e68949f92 100644
--- a/searchlib/src/vespa/searchlib/attribute/postinglistattribute.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/postinglistattribute.cpp
@@ -49,6 +49,7 @@ PostingListAttributeBase<P>::handle_load_posting_lists_and_update_enum_store(enu
PostingChange<P> postings;
const auto& loaded_enums = loader.get_loaded_enums();
if (loaded_enums.empty()) {
+ loader.build_empty_dictionary();
return;
}
uint32_t preve = 0;
diff --git a/searchlib/src/vespa/searchlib/attribute/singleenumattribute.h b/searchlib/src/vespa/searchlib/attribute/singleenumattribute.h
index aac9a7b5416..7f36238ec6a 100644
--- a/searchlib/src/vespa/searchlib/attribute/singleenumattribute.h
+++ b/searchlib/src/vespa/searchlib/attribute/singleenumattribute.h
@@ -67,14 +67,14 @@ protected:
void considerAttributeChange(const Change & c, EnumStoreBatchUpdater & inserter) override;
// implemented by single value numeric enum attribute.
- virtual void considerUpdateAttributeChange(const Change & c) { (void) c; }
+ virtual void considerUpdateAttributeChange(DocId, const Change&) { }
virtual void considerArithmeticAttributeChange(const Change & c, EnumStoreBatchUpdater & inserter) { (void) c; (void) inserter; }
virtual void applyValueChanges(EnumStoreBatchUpdater& updater) ;
virtual void applyArithmeticValueChange(const Change& c, EnumStoreBatchUpdater& updater) {
(void) c; (void) updater;
}
- void updateEnumRefCounts(const Change& c, EnumIndex newIdx, EnumIndex oldIdx, EnumStoreBatchUpdater& updater);
+ void updateEnumRefCounts(DocId doc, EnumIndex newIdx, EnumIndex oldIdx, EnumStoreBatchUpdater& updater);
virtual void freezeEnumDictionary() {
this->getEnumStore().freeze_dictionary();
diff --git a/searchlib/src/vespa/searchlib/attribute/singleenumattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singleenumattribute.hpp
index f4f2b777abd..95976609940 100644
--- a/searchlib/src/vespa/searchlib/attribute/singleenumattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/singleenumattribute.hpp
@@ -146,7 +146,7 @@ SingleValueEnumAttribute<B>::considerUpdateAttributeChange(const Change & c, Enu
} else {
c.set_entry_ref(idx.ref());
}
- considerUpdateAttributeChange(c); // for numeric
+ considerUpdateAttributeChange(c._doc, c); // for numeric
}
template <typename B>
@@ -158,9 +158,7 @@ SingleValueEnumAttribute<B>::considerAttributeChange(const Change & c, EnumStore
} else if (c._type >= ChangeBase::ADD && c._type <= ChangeBase::DIV) {
considerArithmeticAttributeChange(c, inserter); // for numeric
} else if (c._type == ChangeBase::CLEARDOC) {
- Change clearDoc(this->_defaultValue);
- clearDoc._doc = c._doc;
- considerUpdateAttributeChange(clearDoc, inserter);
+ considerUpdateAttributeChange(c._doc, this->_defaultValue);
}
}
@@ -175,7 +173,7 @@ SingleValueEnumAttribute<B>::applyUpdateValueChange(const Change& c, EnumStoreBa
} else {
this->_enumStore.find_index(c._data.raw(), newIdx);
}
- updateEnumRefCounts(c, newIdx, oldIdx, updater);
+ updateEnumRefCounts(c._doc, newIdx, oldIdx, updater);
}
template <typename B>
@@ -183,30 +181,26 @@ void
SingleValueEnumAttribute<B>::applyValueChanges(EnumStoreBatchUpdater& updater)
{
ValueModifier valueGuard(this->getValueModifier());
- // This avoids searching for the defaultValue in the enum store for each CLEARDOC in the change vector.
- this->cache_change_data_entry_ref(this->_defaultValue);
for (const auto& change : this->_changes.getInsertOrder()) {
if (change._type == ChangeBase::UPDATE) {
applyUpdateValueChange(change, updater);
} else if (change._type >= ChangeBase::ADD && change._type <= ChangeBase::DIV) {
applyArithmeticValueChange(change, updater);
} else if (change._type == ChangeBase::CLEARDOC) {
- Change clearDoc(this->_defaultValue);
- clearDoc._doc = change._doc;
- applyUpdateValueChange(clearDoc, updater);
+ EnumIndex oldIdx = _enumIndices[change._doc].load_relaxed();
+ EnumIndex newIdx = this->_enumStore.get_default_value_ref().load_relaxed();
+ updateEnumRefCounts(change._doc, newIdx, oldIdx, updater);
}
}
- // We must clear the cached entry ref as the defaultValue might be located in another data buffer on later invocations.
- this->_defaultValue.clear_entry_ref();
}
template <typename B>
void
-SingleValueEnumAttribute<B>::updateEnumRefCounts(const Change& c, EnumIndex newIdx, EnumIndex oldIdx,
+SingleValueEnumAttribute<B>::updateEnumRefCounts(DocId doc, EnumIndex newIdx, EnumIndex oldIdx,
EnumStoreBatchUpdater& updater)
{
updater.inc_ref_count(newIdx);
- _enumIndices[c._doc].store_release(newIdx);
+ _enumIndices[doc].store_release(newIdx);
if (oldIdx.valid()) {
updater.dec_ref_count(oldIdx);
}
diff --git a/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp
index a105d980986..c75ee0aacb5 100644
--- a/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp
@@ -134,8 +134,9 @@ SingleValueNumericAttribute<B>::onLoad(vespalib::Executor *)
PrimitiveReader<T> attrReader(*this);
bool ok(attrReader.getHasLoadData());
- if (!ok)
+ if (!ok) {
return false;
+ }
this->setCreateSerialNum(attrReader.getCreateSerialNum());
diff --git a/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.h b/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.h
index 5b0e1c6131e..4eeb6ceda57 100644
--- a/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.h
+++ b/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.h
@@ -43,7 +43,7 @@ private:
protected:
// from SingleValueEnumAttribute
- void considerUpdateAttributeChange(const Change & c) override;
+ void considerUpdateAttributeChange(DocId doc, const Change & c) override;
void considerArithmeticAttributeChange(const Change & c, EnumStoreBatchUpdater & inserter) override;
void applyArithmeticValueChange(const Change& c, EnumStoreBatchUpdater& updater) override;
diff --git a/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.hpp
index 52ea0a53533..b840a0516b2 100644
--- a/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.hpp
@@ -15,9 +15,9 @@ namespace search {
template <typename B>
void
-SingleValueNumericEnumAttribute<B>::considerUpdateAttributeChange(const Change & c)
+SingleValueNumericEnumAttribute<B>::considerUpdateAttributeChange(DocId doc, const Change & c)
{
- _currDocValues[c._doc] = c._data.get();
+ _currDocValues[doc] = c._data.get();
}
template <typename B>
@@ -53,7 +53,7 @@ SingleValueNumericEnumAttribute<B>::applyArithmeticValueChange(const Change& c,
T newValue = this->template applyArithmetic<T, typename Change::DataType>(get(c._doc), c._data.getArithOperand(), c._type);
this->_enumStore.find_index(newValue, newIdx);
- this->updateEnumRefCounts(c, newIdx, oldIdx, updater);
+ this->updateEnumRefCounts(c._doc, newIdx, oldIdx, updater);
}
template <typename B>
@@ -117,6 +117,10 @@ SingleValueNumericEnumAttribute<B>::onLoad(vespalib::Executor *)
return false;
}
+ this->_enumStore.clear_default_value_ref();
+ this->commit();
+ this->incGeneration();
+
this->setCreateSerialNum(attrReader.getCreateSerialNum());
if (attrReader.getEnumerated()) {
diff --git a/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp
index de4a7157dae..e353d03a9e8 100644
--- a/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp
@@ -89,8 +89,6 @@ SingleValueNumericPostingAttribute<B>::applyValueChanges(EnumStoreBatchUpdater&
// used to make sure several arithmetic operations on the same document in a single commit works
std::map<DocId, EnumIndex> currEnumIndices;
- // This avoids searching for the defaultValue in the enum store for each CLEARDOC in the change vector.
- this->cache_change_data_entry_ref(this->_defaultValue);
for (const auto& change : this->_changes.getInsertOrder()) {
auto enumIter = currEnumIndices.find(change._doc);
EnumIndex oldIdx;
@@ -111,13 +109,9 @@ SingleValueNumericPostingAttribute<B>::applyValueChanges(EnumStoreBatchUpdater&
currEnumIndices[change._doc] = newIdx;
}
} else if(change._type == ChangeBase::CLEARDOC) {
- Change clearDoc(this->_defaultValue);
- clearDoc._doc = change._doc;
- applyUpdateValueChange(clearDoc, enumStore, currEnumIndices);
+ currEnumIndices[change._doc] = enumStore.get_default_value_ref().load_relaxed();
}
}
- // We must clear the cached entry ref as the defaultValue might be located in another data buffer on later invocations.
- this->_defaultValue.clear_entry_ref();
makePostingChange(enumStore.get_comparator(), currEnumIndices, changePost);
diff --git a/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp
index 82a4393fc91..69fe6435a03 100644
--- a/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp
@@ -40,7 +40,6 @@ SingleValueStringAttributeT<B>::freezeEnumDictionary()
this->getEnumStore().freeze_dictionary();
}
-
template <typename B>
std::unique_ptr<attribute::SearchContext>
SingleValueStringAttributeT<B>::getSearch(QueryTermSimpleUP qTerm,
diff --git a/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp
index 1ec9b54a73b..5b5214f6d3e 100644
--- a/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp
@@ -98,8 +98,6 @@ SingleValueStringPostingAttributeT<B>::applyValueChanges(EnumStoreBatchUpdater&
// used to make sure several arithmetic operations on the same document in a single commit works
std::map<DocId, EnumIndex> currEnumIndices;
- // This avoids searching for the defaultValue in the enum store for each CLEARDOC in the change vector.
- this->cache_change_data_entry_ref(this->_defaultValue);
for (const auto& change : this->_changes.getInsertOrder()) {
auto enumIter = currEnumIndices.find(change._doc);
EnumIndex oldIdx;
@@ -111,12 +109,9 @@ SingleValueStringPostingAttributeT<B>::applyValueChanges(EnumStoreBatchUpdater&
if (change._type == ChangeBase::UPDATE) {
applyUpdateValueChange(change, enumStore, currEnumIndices);
} else if (change._type == ChangeBase::CLEARDOC) {
- this->_defaultValue._doc = change._doc;
- applyUpdateValueChange(this->_defaultValue, enumStore, currEnumIndices);
+ currEnumIndices[change._doc] = enumStore.get_default_value_ref().load_relaxed();
}
}
- // We must clear the cached entry ref as the defaultValue might be located in another data buffer on later invocations.
- this->_defaultValue.clear_entry_ref();
makePostingChange(enumStore.get_folded_comparator(), dictionary, currEnumIndices, changePost);
diff --git a/searchlib/src/vespa/searchlib/attribute/stringbase.cpp b/searchlib/src/vespa/searchlib/attribute/stringbase.cpp
index 80967affaa7..b37318d470e 100644
--- a/searchlib/src/vespa/searchlib/attribute/stringbase.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/stringbase.cpp
@@ -223,6 +223,10 @@ StringAttribute::onLoad(vespalib::Executor *)
return false;
}
+ getEnumStoreBase()->clear_default_value_ref();
+ commit();
+ incGeneration();
+
setCreateSerialNum(attrReader.getCreateSerialNum());
assert(attrReader.getEnumerated());
diff --git a/searchlib/src/vespa/searchlib/attribute/stringbase.h b/searchlib/src/vespa/searchlib/attribute/stringbase.h
index 5c6bf3c6b6a..98a3316947b 100644
--- a/searchlib/src/vespa/searchlib/attribute/stringbase.h
+++ b/searchlib/src/vespa/searchlib/attribute/stringbase.h
@@ -62,7 +62,7 @@ protected:
using ChangeVector = ChangeVectorT<Change>;
using EnumEntryType = const char*;
ChangeVector _changes;
- Change _defaultValue;
+ const Change _defaultValue;
bool onLoad(vespalib::Executor *executor) override;
bool onLoadEnumerated(ReaderBase &attrReader);
diff --git a/searchlib/src/vespa/searchlib/query/query_term_simple.h b/searchlib/src/vespa/searchlib/query/query_term_simple.h
index 74728ab1f2e..a79e33dba32 100644
--- a/searchlib/src/vespa/searchlib/query/query_term_simple.h
+++ b/searchlib/src/vespa/searchlib/query/query_term_simple.h
@@ -23,7 +23,8 @@ public:
SUFFIXTERM = 4,
REGEXP = 5,
GEO_LOCATION = 6,
- FUZZYTERM = 7
+ FUZZYTERM = 7,
+ NEAREST_NEIGHBOR = 8
};
template <typename N>
@@ -65,6 +66,7 @@ public:
bool isRegex() const { return (_type == Type::REGEXP); }
bool isGeoLoc() const { return (_type == Type::GEO_LOCATION); }
bool isFuzzy() const { return (_type == Type::FUZZYTERM); }
+ bool is_nearest_neighbor() const noexcept { return (_type == Type::NEAREST_NEIGHBOR); }
bool empty() const { return _term.empty(); }
virtual void visitMembers(vespalib::ObjectVisitor &visitor) const;
vespalib::string getClassName() const;
diff --git a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt
index 27f9870dc18..c71b838fb37 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt
@@ -1,6 +1,7 @@
# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
vespa_add_library(searchlib_query_streaming OBJECT
SOURCES
+ nearest_neighbor_query_node.cpp
query.cpp
querynode.cpp
querynoderesultbase.cpp
diff --git a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp
new file mode 100644
index 00000000000..d1c37cd6dcd
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp
@@ -0,0 +1,36 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "nearest_neighbor_query_node.h"
+
+namespace search::streaming {
+
+NearestNeighborQueryNode::NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, const string& term, const string& index, int32_t id, search::query::Weight weight, double distance_threshold)
+ : QueryTerm(std::move(resultBase), term, index, Type::NEAREST_NEIGHBOR),
+ _distance_threshold(distance_threshold),
+ _raw_score()
+{
+ setUniqueId(id);
+ setWeight(weight);
+}
+
+NearestNeighborQueryNode::~NearestNeighborQueryNode() = default;
+
+bool
+NearestNeighborQueryNode::evaluate() const
+{
+ return _raw_score.has_value();
+}
+
+void
+NearestNeighborQueryNode::reset()
+{
+ _raw_score.reset();
+}
+
+NearestNeighborQueryNode*
+NearestNeighborQueryNode::as_nearest_neighbor_query_node() noexcept
+{
+ return this;
+}
+
+}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h
new file mode 100644
index 00000000000..0beb130c53d
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h
@@ -0,0 +1,35 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "queryterm.h"
+#include <optional>
+
+namespace search::streaming {
+
+/*
+ * Nearest neighbor query node.
+ */
+class NearestNeighborQueryNode: public QueryTerm {
+private:
+ double _distance_threshold;
+ // When this value is set it also indicates a match
+ std::optional<double> _raw_score;
+
+public:
+ NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, const string& term, const string& index, int32_t id, search::query::Weight weight, double distance_threshold);
+ NearestNeighborQueryNode(const NearestNeighborQueryNode &) = delete;
+ NearestNeighborQueryNode & operator = (const NearestNeighborQueryNode &) = delete;
+ NearestNeighborQueryNode(NearestNeighborQueryNode &&) = delete;
+ NearestNeighborQueryNode & operator = (NearestNeighborQueryNode &&) = delete;
+ ~NearestNeighborQueryNode() override;
+ bool evaluate() const override;
+ void reset() override;
+ NearestNeighborQueryNode* as_nearest_neighbor_query_node() noexcept override;
+ const vespalib::string& get_query_tensor_name() const { return getTermString(); }
+ double get_distance_threshold() const { return _distance_threshold; }
+ void set_raw_score(double value) { _raw_score = value; }
+ const std::optional<double>& get_raw_score() const noexcept { return _raw_score; }
+};
+
+}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
index 6d59886a4f5..226cb92c894 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
+++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "query.h"
+#include "nearest_neighbor_query_node.h"
#include <vespa/searchlib/parsequery/stackdumpiterator.h>
#include <charconv>
#include <vespa/log/log.h>
@@ -77,6 +78,9 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor
queryRep.getIndexName(),
QueryTerm::Type::GEO_LOCATION);
break;
+ case ParseItem::ITEM_NEAREST_NEIGHBOR:
+ qn = build_nearest_neighbor_query_node(factory, queryRep);
+ break;
case ParseItem::ITEM_NUMTERM:
case ParseItem::ITEM_TERM:
case ParseItem::ITEM_PREFIXTERM:
@@ -191,4 +195,20 @@ const HitList & QueryNode::evaluateHits(HitList & hl) const
return hl;
}
+std::unique_ptr<QueryNode>
+QueryNode::build_nearest_neighbor_query_node(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& query_rep)
+{
+ vespalib::stringref query_tensor_name = query_rep.getTerm();
+ vespalib::stringref field_name = query_rep.getIndexName();
+ int32_t id = query_rep.getUniqueId();
+ search::query::Weight weight = query_rep.GetWeight();
+ double distance_threshold = query_rep.getDistanceThreshold();
+ return std::make_unique<NearestNeighborQueryNode>(factory.create(),
+ query_tensor_name,
+ field_name,
+ id,
+ weight,
+ distance_threshold);
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.h b/searchlib/src/vespa/searchlib/query/streaming/querynode.h
index 574a3c16ca3..c3fa2b63f69 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/querynode.h
+++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.h
@@ -28,6 +28,7 @@ using ConstQueryTermList = std::vector<const QueryTerm *>;
*/
class QueryNode
{
+ static std::unique_ptr<QueryNode> build_nearest_neighbor_query_node(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep);
public:
using UP = std::unique_ptr<QueryNode>;
@@ -54,7 +55,7 @@ class QueryNode
virtual size_t depth() const { return 1; }
/// Return the width of this tree.
virtual size_t width() const { return 1; }
- static UP Build(const QueryNode * parent, const QueryNodeResultFactory & org, SimpleQueryStackDumpIterator & queryRep, bool allowRewrite);
+ static UP Build(const QueryNode * parent, const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator & queryRep, bool allowRewrite);
};
/// A list conating the QuerNode objects. With copy/assignment.
diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp
index 83f4410a520..11557bf1dcc 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp
+++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp
@@ -92,4 +92,10 @@ void QueryTerm::add(unsigned pos, unsigned context, uint32_t elemId, int32_t wei
_hitList.emplace_back(pos, context, elemId, weight_);
}
+NearestNeighborQueryNode*
+QueryTerm::as_nearest_neighbor_query_node() noexcept
+{
+ return nullptr;
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h
index dd9f56b11e1..51987225692 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h
+++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h
@@ -12,6 +12,8 @@
namespace search::streaming {
+class NearestNeighborQueryNode;
+
/**
This is a leaf in the Query tree. All terms are leafs.
A QueryTerm has the index for where to find the term. The term is a string,
@@ -57,7 +59,7 @@ public:
QueryTerm & operator = (const QueryTerm &) = delete;
QueryTerm(QueryTerm &&) = delete;
QueryTerm & operator = (QueryTerm &&) = delete;
- ~QueryTerm();
+ ~QueryTerm() override;
bool evaluate() const override;
const HitList & evaluateHits(HitList & hl) const override;
void reset() override;
@@ -87,6 +89,7 @@ public:
const string & getIndex() const override { return _index; }
void setFuzzyMaxEditDistance(uint32_t fuzzyMaxEditDistance) { _fuzzyMaxEditDistance = fuzzyMaxEditDistance; }
void setFuzzyPrefixLength(uint32_t fuzzyPrefixLength) { _fuzzyPrefixLength = fuzzyPrefixLength; }
+ virtual NearestNeighborQueryNode* as_nearest_neighbor_query_node() noexcept;
protected:
using QueryNodeResultBaseContainer = std::unique_ptr<QueryNodeResultBase>;
string _index;
diff --git a/security-utils/src/main/java/com/yahoo/security/SideChannelSafe.java b/security-utils/src/main/java/com/yahoo/security/SideChannelSafe.java
index 1f160d94c6a..bd085f6f624 100644
--- a/security-utils/src/main/java/com/yahoo/security/SideChannelSafe.java
+++ b/security-utils/src/main/java/com/yahoo/security/SideChannelSafe.java
@@ -46,7 +46,7 @@ public class SideChannelSafe {
// differed in any byte compared between the two arrays.
byte accu = 0;
for (int i = 0; i < lhs.length; ++i) {
- accu |= (lhs[i] ^ rhs[i]);
+ accu |= (byte)(lhs[i] ^ rhs[i]);
}
return (accu == 0);
}
diff --git a/security-utils/src/test/java/com/yahoo/security/SharedKeyTest.java b/security-utils/src/test/java/com/yahoo/security/SharedKeyTest.java
index 26627e9a5fa..90b8beb461f 100644
--- a/security-utils/src/test/java/com/yahoo/security/SharedKeyTest.java
+++ b/security-utils/src/test/java/com/yahoo/security/SharedKeyTest.java
@@ -285,12 +285,12 @@ public class SharedKeyTest {
String plaintext = "...hello world?";
byte[] encrypted = streamEncryptString(plaintext, myShared);
// Corrupt MAC tag in ciphertext
- encrypted[encrypted.length - 1] ^= 0x80;
+ encrypted[encrypted.length - 1] ^= (byte)0x80;
// We don't necessarily know _which_ exception is thrown, but one _should_ be thrown!
assertThrows(Exception.class, () -> doOutputStreamCipherDecrypt(myShared, encrypted));
// Also try with corrupted ciphertext (pre MAC tag)
- encrypted[encrypted.length - 1] ^= 0x80; // Flip MAC bit back to correct state
- encrypted[encrypted.length - 17] ^= 0x80; // Pre 128-bit MAC tag
+ encrypted[encrypted.length - 1] ^= (byte)0x80; // Flip MAC bit back to correct state
+ encrypted[encrypted.length - 17] ^= (byte)0x80; // Pre 128-bit MAC tag
assertThrows(Exception.class, () -> doOutputStreamCipherDecrypt(myShared, encrypted));
}
diff --git a/storage/src/tests/persistence/persistencetestutils.h b/storage/src/tests/persistence/persistencetestutils.h
index 94ae7b9fb53..e60260f3ee8 100644
--- a/storage/src/tests/persistence/persistencetestutils.h
+++ b/storage/src/tests/persistence/persistencetestutils.h
@@ -150,6 +150,18 @@ public:
_replySender, MockBucketLock::make(bucket, _mock_bucket_locks), std::move(cmd));
}
+ template <typename T>
+ requires std::is_base_of_v<api::StorageReply, T>
+ [[nodiscard]] std::shared_ptr<T>
+ fetch_single_reply(MessageTracker::UP tracker) {
+ if (tracker && tracker->hasReply()) {
+ tracker->sendReply(); // Forward to queue so we can fetch it below
+ }
+ std::shared_ptr<api::StorageMessage> msg;
+ _replySender.queue.getNext(msg, 60s);
+ return std::dynamic_pointer_cast<T>(msg);
+ }
+
api::ReturnCode
fetchResult(const MessageTracker::UP & tracker) {
if (tracker) {
diff --git a/storage/src/tests/persistence/testandsettest.cpp b/storage/src/tests/persistence/testandsettest.cpp
index 5be1c7cd92a..1aa359de634 100644
--- a/storage/src/tests/persistence/testandsettest.cpp
+++ b/storage/src/tests/persistence/testandsettest.cpp
@@ -1,16 +1,16 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
// @author Vegard Sjonfjell
-#include <vespa/storage/persistence/persistencehandler.h>
#include <tests/persistence/persistencetestutils.h>
#include <vespa/document/test/make_document_bucket.h>
-#include <vespa/documentapi/messagebus/messages/testandsetcondition.h>
#include <vespa/document/fieldvalue/fieldvalues.h>
#include <vespa/document/update/documentupdate.h>
#include <vespa/document/update/assignvalueupdate.h>
#include <vespa/document/fieldset/fieldsets.h>
+#include <vespa/documentapi/messagebus/messages/testandsetcondition.h>
#include <vespa/persistence/spi/test.h>
#include <vespa/persistence/spi/persistenceprovider.h>
#include <vespa/persistence/spi/docentry.h>
+#include <vespa/storage/persistence/persistencehandler.h>
#include <functional>
using std::unique_ptr;
@@ -19,6 +19,7 @@ using std::shared_ptr;
using storage::spi::test::makeSpiBucket;
using document::test::makeDocumentBucket;
using document::StringFieldValue;
+using documentapi::TestAndSetCondition;
using namespace ::testing;
namespace storage {
@@ -34,15 +35,18 @@ struct TestAndSetTest : PersistenceTestUtils {
const StringFieldValue OLD_CONTENT{"Some old content"};
const StringFieldValue NEW_CONTENT{"Freshly pressed and squeezed content"};
const document::Bucket BUCKET = makeDocumentBucket(BUCKET_ID);
+ const TestAndSetCondition MATCHING_CONDITION{"testdoctype1.hstringval=\"*woofy dog*\""};
unique_ptr<PersistenceHandler> persistenceHandler;
const AsyncHandler * asyncHandler;
+ const SimpleMessageHandler* simple_handler;
shared_ptr<document::Document> testDoc;
document::DocumentId testDocId;
TestAndSetTest()
: persistenceHandler(),
- asyncHandler(nullptr)
+ asyncHandler(nullptr),
+ simple_handler(nullptr)
{}
void SetUp() override {
@@ -54,6 +58,7 @@ struct TestAndSetTest : PersistenceTestUtils {
testDoc = createTestDocument();
testDocId = testDoc->getId();
asyncHandler = &_persistenceHandler->asyncHandler();
+ simple_handler = &_persistenceHandler->simpleMessageHandler();
}
void TearDown() override {
@@ -68,6 +73,8 @@ struct TestAndSetTest : PersistenceTestUtils {
document::Document::SP retrieveTestDocument();
void setTestCondition(api::TestAndSetCommand & command);
void putTestDocument(bool matchingHeader, api::Timestamp timestamp);
+ std::shared_ptr<api::GetReply> invoke_conditional_get();
+ void feed_remove_entry_with_timestamp(api::Timestamp timestamp);
void assertTestDocumentFoundAndMatchesContent(const document::FieldValue & value);
static std::string expectedDocEntryString(
@@ -247,6 +254,59 @@ TEST_F(TestAndSetTest, conditional_put_to_non_existing_document_should_fail) {
EXPECT_EQ("", dumpBucket(BUCKET_ID));
}
+TEST_F(TestAndSetTest, conditional_get_returns_doc_metadata_on_match) {
+ const api::Timestamp timestamp = 12345;
+ putTestDocument(true, timestamp);
+ auto reply = invoke_conditional_get();
+
+ ASSERT_EQ(reply->getResult(), api::ReturnCode());
+ EXPECT_EQ(reply->getLastModifiedTimestamp(), timestamp);
+ EXPECT_TRUE(reply->condition_matched());
+ EXPECT_FALSE(reply->is_tombstone());
+ // Checking reply->wasFound() is tempting but doesn't make sense here, as that checks for
+ // the presence of a document object, which metadata-only gets by definition do not return.
+}
+
+TEST_F(TestAndSetTest, conditional_get_returns_doc_metadata_on_mismatch) {
+ const api::Timestamp timestamp = 12345;
+ putTestDocument(false, timestamp);
+ auto reply = invoke_conditional_get();
+
+ ASSERT_EQ(reply->getResult(), api::ReturnCode());
+ EXPECT_EQ(reply->getLastModifiedTimestamp(), timestamp);
+ EXPECT_FALSE(reply->condition_matched());
+ EXPECT_FALSE(reply->is_tombstone());
+}
+
+TEST_F(TestAndSetTest, conditional_get_for_non_existing_document_returns_zero_timestamp) {
+ auto reply = invoke_conditional_get();
+
+ ASSERT_EQ(reply->getResult(), api::ReturnCode());
+ EXPECT_EQ(reply->getLastModifiedTimestamp(), 0);
+ EXPECT_FALSE(reply->condition_matched());
+ EXPECT_FALSE(reply->is_tombstone());
+}
+
+TEST_F(TestAndSetTest, conditional_get_for_non_existing_document_with_explicit_tombstone_returns_tombstone_timestamp) {
+ api::Timestamp timestamp = 56789;
+ feed_remove_entry_with_timestamp(timestamp);
+ auto reply = invoke_conditional_get();
+
+ ASSERT_EQ(reply->getResult(), api::ReturnCode());
+ EXPECT_EQ(reply->getLastModifiedTimestamp(), timestamp);
+ EXPECT_FALSE(reply->condition_matched());
+ EXPECT_TRUE(reply->is_tombstone());
+}
+
+TEST_F(TestAndSetTest, conditional_get_requires_metadata_only_fieldset) {
+ auto get = std::make_shared<api::GetCommand>(BUCKET, testDocId, document::AllFields::NAME);
+ get->set_condition(MATCHING_CONDITION);
+ // Note: uses fetchResult instead of fetch_single_reply due to implicit failure signalling via tracker instance.
+ auto result = fetchResult(simple_handler->handleGet(*get, createTracker(get, BUCKET)));
+ ASSERT_EQ(result, api::ReturnCode(api::ReturnCode::ILLEGAL_PARAMETERS,
+ "Conditional Get operations must be metadata-only"));
+}
+
document::Document::SP
TestAndSetTest::createTestDocument()
{
@@ -270,7 +330,7 @@ TestAndSetTest::retrieveTestDocument()
auto tracker = _persistenceHandler->simpleMessageHandler().handleGet(*get, createTracker(get, BUCKET));
assert(tracker->getResult() == api::ReturnCode::Result::OK);
- auto & reply = static_cast<api::GetReply &>(tracker->getReply());
+ auto& reply = dynamic_cast<api::GetReply&>(tracker->getReply());
assert(reply.wasFound());
return reply.getDocument();
@@ -278,7 +338,7 @@ TestAndSetTest::retrieveTestDocument()
void TestAndSetTest::setTestCondition(api::TestAndSetCommand & command)
{
- command.setCondition(documentapi::TestAndSetCondition("testdoctype1.hstringval=\"*woofy dog*\""));
+ command.setCondition(MATCHING_CONDITION);
}
void TestAndSetTest::putTestDocument(bool matchingHeader, api::Timestamp timestamp) {
@@ -290,6 +350,17 @@ void TestAndSetTest::putTestDocument(bool matchingHeader, api::Timestamp timesta
fetchResult(asyncHandler->handlePut(*put, createTracker(put, BUCKET)));
}
+std::shared_ptr<api::GetReply> TestAndSetTest::invoke_conditional_get() {
+ auto get = std::make_shared<api::GetCommand>(BUCKET, testDocId, document::NoFields::NAME);
+ get->set_condition(MATCHING_CONDITION);
+ return fetch_single_reply<api::GetReply>(simple_handler->handleGet(*get, createTracker(get, BUCKET)));
+}
+
+void TestAndSetTest::feed_remove_entry_with_timestamp(api::Timestamp timestamp) {
+ auto remove = std::make_shared<api::RemoveCommand>(BUCKET, testDocId, timestamp);
+ (void)fetchResult(asyncHandler->handleRemove(*remove, createTracker(remove, BUCKET)));
+}
+
void TestAndSetTest::assertTestDocumentFoundAndMatchesContent(const document::FieldValue & value)
{
auto doc = retrieveTestDocument();
diff --git a/storage/src/tests/storageapi/mbusprot/storageprotocoltest.cpp b/storage/src/tests/storageapi/mbusprot/storageprotocoltest.cpp
index d3036a2fad3..6d8c3585726 100644
--- a/storage/src/tests/storageapi/mbusprot/storageprotocoltest.cpp
+++ b/storage/src/tests/storageapi/mbusprot/storageprotocoltest.cpp
@@ -848,7 +848,7 @@ TEST_P(StorageProtocolTest, track_memory_footprint_for_some_messages) {
EXPECT_EQ(144u + sizeof(vespalib::string), sizeof(PutCommand));
EXPECT_EQ(144u + sizeof(vespalib::string), sizeof(UpdateCommand));
EXPECT_EQ(224u + sizeof(vespalib::string), sizeof(RemoveCommand));
- EXPECT_EQ(296u, sizeof(GetCommand));
+ EXPECT_EQ(296u + sizeof(documentapi::TestAndSetCondition), sizeof(GetCommand));
}
} // storage::api
diff --git a/storage/src/vespa/storage/persistence/asynchandler.cpp b/storage/src/vespa/storage/persistence/asynchandler.cpp
index e20c0475556..60c6d507416 100644
--- a/storage/src/vespa/storage/persistence/asynchandler.cpp
+++ b/storage/src/vespa/storage/persistence/asynchandler.cpp
@@ -358,7 +358,11 @@ bool
AsyncHandler::tasConditionMatches(const api::TestAndSetCommand & cmd, MessageTracker & tracker,
spi::Context & context, bool missingDocumentImpliesMatch) const {
try {
- TestAndSetHelper helper(_env, _spi, _bucketIdFactory, cmd, missingDocumentImpliesMatch);
+ TestAndSetHelper helper(_env, _spi, _bucketIdFactory,
+ cmd.getCondition(),
+ cmd.getBucket(), cmd.getDocumentId(),
+ cmd.getDocumentType(),
+ missingDocumentImpliesMatch);
auto code = helper.retrieveAndMatch(context);
if (code.failed()) {
diff --git a/storage/src/vespa/storage/persistence/persistencehandler.cpp b/storage/src/vespa/storage/persistence/persistencehandler.cpp
index 8d71cc9308b..69f910d0910 100644
--- a/storage/src/vespa/storage/persistence/persistencehandler.cpp
+++ b/storage/src/vespa/storage/persistence/persistencehandler.cpp
@@ -24,7 +24,7 @@ PersistenceHandler::PersistenceHandler(vespalib::ISequencedTaskExecutor & sequen
cfg.commonMergeChainOptimalizationMinimumSize),
_asyncHandler(_env, provider, bucketOwnershipNotifier, sequencedExecutor, component.getBucketIdFactory()),
_splitJoinHandler(_env, provider, bucketOwnershipNotifier, cfg.enableMultibitSplitOptimalization),
- _simpleHandler(_env, provider)
+ _simpleHandler(_env, provider, component.getBucketIdFactory())
{
}
diff --git a/storage/src/vespa/storage/persistence/simplemessagehandler.cpp b/storage/src/vespa/storage/persistence/simplemessagehandler.cpp
index e83d460f47a..ea929bf8620 100644
--- a/storage/src/vespa/storage/persistence/simplemessagehandler.cpp
+++ b/storage/src/vespa/storage/persistence/simplemessagehandler.cpp
@@ -2,6 +2,7 @@
#include "simplemessagehandler.h"
#include "persistenceutil.h"
+#include "testandsethelper.h"
#include <vespa/persistence/spi/persistenceprovider.h>
#include <vespa/persistence/spi/docentry.h>
#include <vespa/storageapi/message/bucket.h>
@@ -45,21 +46,45 @@ getFieldSet(const document::FieldSetRepo & repo, vespalib::stringref name, Messa
}
}
-SimpleMessageHandler::SimpleMessageHandler(const PersistenceUtil& env, spi::PersistenceProvider& spi)
+SimpleMessageHandler::SimpleMessageHandler(const PersistenceUtil& env,
+ spi::PersistenceProvider& spi,
+ const document::BucketIdFactory& bucket_id_factory)
: _env(env),
- _spi(spi)
+ _spi(spi),
+ _bucket_id_factory(bucket_id_factory)
{
}
MessageTracker::UP
+SimpleMessageHandler::handle_conditional_get(api::GetCommand& cmd, MessageTracker::UP tracker) const
+{
+ if (cmd.getFieldSet() == document::NoFields::NAME) {
+ TestAndSetHelper tas_helper(_env, _spi, _bucket_id_factory, cmd.condition(),
+ cmd.getBucket(), cmd.getDocumentId(), nullptr);
+ auto result = tas_helper.fetch_and_match_raw(tracker->context());
+ tracker->setReply(std::make_shared<api::GetReply>(cmd, nullptr, result.timestamp, false,
+ result.is_tombstone(), result.is_match()));
+ } else {
+ tracker->fail(api::ReturnCode::ILLEGAL_PARAMETERS, "Conditional Get operations must be metadata-only");
+ }
+ return tracker;
+}
+
+MessageTracker::UP
SimpleMessageHandler::handleGet(api::GetCommand& cmd, MessageTracker::UP tracker) const
{
auto& metrics = _env._metrics.get;
tracker->setMetric(metrics);
metrics.request_size.addValue(cmd.getApproxByteSize());
+ if (cmd.has_condition()) {
+ return handle_conditional_get(cmd, std::move(tracker));
+ }
+
auto fieldSet = getFieldSet(_env.getFieldSetRepo(), cmd.getFieldSet(), *tracker);
- if ( ! fieldSet) { return tracker; }
+ if (!fieldSet) {
+ return tracker;
+ }
tracker->context().setReadConsistency(api_read_consistency_to_spi(cmd.internal_read_consistency()));
spi::GetResult result = _spi.get(_env.getBucket(cmd.getDocumentId(), cmd.getBucket()),
@@ -70,7 +95,7 @@ SimpleMessageHandler::handleGet(api::GetCommand& cmd, MessageTracker::UP tracker
metrics.notFound.inc();
}
tracker->setReply(std::make_shared<api::GetReply>(cmd, result.getDocumentPtr(), result.getTimestamp(),
- false, result.is_tombstone()));
+ false, result.is_tombstone(), false));
}
return tracker;
diff --git a/storage/src/vespa/storage/persistence/simplemessagehandler.h b/storage/src/vespa/storage/persistence/simplemessagehandler.h
index 009fd6dff52..a5a19772556 100644
--- a/storage/src/vespa/storage/persistence/simplemessagehandler.h
+++ b/storage/src/vespa/storage/persistence/simplemessagehandler.h
@@ -7,6 +7,8 @@
#include <vespa/storage/common/bucketmessages.h>
#include <vespa/storageapi/message/persistence.h>
+namespace document { class BucketIdFactory; }
+
namespace storage {
namespace spi { struct PersistenceProvider; }
@@ -19,7 +21,9 @@ class PersistenceUtil;
*/
class SimpleMessageHandler : public Types {
public:
- SimpleMessageHandler(const PersistenceUtil&, spi::PersistenceProvider&);
+ SimpleMessageHandler(const PersistenceUtil&,
+ spi::PersistenceProvider&,
+ const document::BucketIdFactory&);
MessageTrackerUP handleGet(api::GetCommand& cmd, MessageTrackerUP tracker) const;
MessageTrackerUP handleRevert(api::RevertCommand& cmd, MessageTrackerUP tracker) const;
MessageTrackerUP handleCreateIterator(CreateIteratorCommand& cmd, MessageTrackerUP tracker) const;
@@ -27,8 +31,11 @@ public:
MessageTrackerUP handleReadBucketList(ReadBucketList& cmd, MessageTrackerUP tracker) const;
MessageTrackerUP handleReadBucketInfo(ReadBucketInfo& cmd, MessageTrackerUP tracker) const;
private:
- const PersistenceUtil & _env;
- spi::PersistenceProvider & _spi;
+ MessageTrackerUP handle_conditional_get(api::GetCommand& cmd, MessageTrackerUP tracker) const;
+
+ const PersistenceUtil& _env;
+ spi::PersistenceProvider& _spi;
+ const document::BucketIdFactory& _bucket_id_factory;
};
} // storage
diff --git a/storage/src/vespa/storage/persistence/testandsethelper.cpp b/storage/src/vespa/storage/persistence/testandsethelper.cpp
index 393dac09f72..1cda9427761 100644
--- a/storage/src/vespa/storage/persistence/testandsethelper.cpp
+++ b/storage/src/vespa/storage/persistence/testandsethelper.cpp
@@ -31,69 +31,91 @@ void TestAndSetHelper::parseDocumentSelection(const document::DocumentTypeRepo &
document::select::Parser parser(documentTypeRepo, bucketIdFactory);
try {
- _docSelectionUp = parser.parse(_cmd.getCondition().getSelection());
+ _docSelectionUp = parser.parse(_condition.getSelection());
} catch (const document::select::ParsingFailedException & e) {
throw TestAndSetException(api::ReturnCode(api::ReturnCode::ILLEGAL_PARAMETERS, "Failed to parse test and set condition: "s + e.getMessage()));
}
}
spi::GetResult TestAndSetHelper::retrieveDocument(const document::FieldSet & fieldSet, spi::Context & context) {
- return _spi.get(_env.getBucket(_docId, _cmd.getBucket()), fieldSet, _cmd.getDocumentId(), context);
+ return _spi.get(_env.getBucket(_docId, _bucket), fieldSet, _docId, context);
}
-TestAndSetHelper::TestAndSetHelper(const PersistenceUtil & env, const spi::PersistenceProvider & spi,
- const document::BucketIdFactory & bucketFactory,
- const api::TestAndSetCommand & cmd, bool missingDocumentImpliesMatch)
+TestAndSetHelper::TestAndSetHelper(const PersistenceUtil& env,
+ const spi::PersistenceProvider& spi,
+ const document::BucketIdFactory& bucket_id_factory,
+ const documentapi::TestAndSetCondition& condition,
+ document::Bucket bucket,
+ document::DocumentId doc_id,
+ const document::DocumentType* doc_type_ptr,
+ bool missingDocumentImpliesMatch)
: _env(env),
_spi(spi),
- _cmd(cmd),
- _docId(cmd.getDocumentId()),
- _docTypePtr(_cmd.getDocumentType()),
+ _condition(condition),
+ _bucket(bucket),
+ _docId(std::move(doc_id)),
+ _docTypePtr(doc_type_ptr),
_missingDocumentImpliesMatch(missingDocumentImpliesMatch)
{
const auto & repo = _env.getDocumentTypeRepo();
resolveDocumentType(repo);
- parseDocumentSelection(repo, bucketFactory);
+ parseDocumentSelection(repo, bucket_id_factory);
}
TestAndSetHelper::~TestAndSetHelper() = default;
-api::ReturnCode
-TestAndSetHelper::retrieveAndMatch(spi::Context & context) {
- // Walk document selection tree to build a minimal field set
+TestAndSetHelper::Result
+TestAndSetHelper::fetch_and_match_raw(spi::Context& context) {
+ // Walk document selection tree to build a minimal field set
FieldVisitor fieldVisitor(*_docTypePtr);
try {
_docSelectionUp->visit(fieldVisitor);
} catch (const document::FieldNotFoundException& e) {
- return api::ReturnCode(api::ReturnCode::ILLEGAL_PARAMETERS,
- vespalib::make_string("Condition field '%s' could not be found, or is an imported field. "
- "Imported fields are not supported in conditional mutations.",
- e.getFieldName().c_str()));
+ throw TestAndSetException(api::ReturnCode(
+ api::ReturnCode::ILLEGAL_PARAMETERS,
+ vespalib::make_string("Condition field '%s' could not be found, or is an imported field. "
+ "Imported fields are not supported in conditional mutations.",
+ e.getFieldName().c_str())));
}
-
- // Retrieve document
auto result = retrieveDocument(fieldVisitor.getFieldSet(), context);
-
// If document exists, match it with selection
if (result.hasDocument()) {
auto docPtr = result.getDocumentPtr();
if (_docSelectionUp->contains(*docPtr) != document::select::Result::True) {
- return api::ReturnCode(api::ReturnCode::TEST_AND_SET_CONDITION_FAILED,
- vespalib::make_string("Condition did not match document nodeIndex=%d bucket=%" PRIx64 " %s",
- _env._nodeIndex, _cmd.getBucketId().getRawId(),
- _cmd.hasBeenRemapped() ? "remapped" : ""));
+ return {result.getTimestamp(), Result::ConditionOutcome::IsNotMatch};
}
-
// Document matches
- return api::ReturnCode();
- } else if (_missingDocumentImpliesMatch) {
- return api::ReturnCode();
+ return {result.getTimestamp(), Result::ConditionOutcome::IsMatch};
}
+ return {result.getTimestamp(), result.is_tombstone() ? Result::ConditionOutcome::IsTombstone
+ : Result::ConditionOutcome::DocNotFound};
+}
- return api::ReturnCode(api::ReturnCode::TEST_AND_SET_CONDITION_FAILED,
- vespalib::make_string("Document does not exist nodeIndex=%d bucket=%" PRIx64 " %s",
- _env._nodeIndex, _cmd.getBucketId().getRawId(),
- _cmd.hasBeenRemapped() ? "remapped" : ""));
+api::ReturnCode
+TestAndSetHelper::to_api_return_code(const Result& result) const {
+ switch (result.condition_outcome) {
+ case Result::ConditionOutcome::IsNotMatch:
+ return {api::ReturnCode::TEST_AND_SET_CONDITION_FAILED,
+ vespalib::make_string("Condition did not match document nodeIndex=%d bucket=%" PRIx64,
+ _env._nodeIndex, _bucket.getBucketId().getRawId())};
+ case Result::ConditionOutcome::IsTombstone:
+ case Result::ConditionOutcome::DocNotFound:
+ if (!_missingDocumentImpliesMatch) {
+ return {api::ReturnCode::TEST_AND_SET_CONDITION_FAILED,
+ vespalib::make_string("Document does not exist nodeIndex=%d bucket=%" PRIx64,
+ _env._nodeIndex, _bucket.getBucketId().getRawId())};
+ }
+ [[fallthrough]]; // as match
+ case Result::ConditionOutcome::IsMatch:
+ return {}; // OK
+ }
+ abort();
+}
+
+api::ReturnCode
+TestAndSetHelper::retrieveAndMatch(spi::Context & context) {
+ auto result = fetch_and_match_raw(context);
+ return to_api_return_code(result);
}
} // storage
diff --git a/storage/src/vespa/storage/persistence/testandsethelper.h b/storage/src/vespa/storage/persistence/testandsethelper.h
index 82710e523c4..31b1cc79a54 100644
--- a/storage/src/vespa/storage/persistence/testandsethelper.h
+++ b/storage/src/vespa/storage/persistence/testandsethelper.h
@@ -25,9 +25,8 @@ class PersistenceUtil;
class TestAndSetException : public std::runtime_error {
api::ReturnCode _code;
-
public:
- TestAndSetException(api::ReturnCode code)
+ explicit TestAndSetException(api::ReturnCode code)
: std::runtime_error(code.getMessage()),
_code(std::move(code))
{}
@@ -36,11 +35,12 @@ public:
};
class TestAndSetHelper {
- const PersistenceUtil &_env;
- const spi::PersistenceProvider &_spi;
- const api::TestAndSetCommand &_cmd;
+ const PersistenceUtil& _env;
+ const spi::PersistenceProvider& _spi;
+ const documentapi::TestAndSetCondition& _condition;
+ const document::Bucket _bucket;
const document::DocumentId _docId;
- const document::DocumentType * _docTypePtr;
+ const document::DocumentType* _docTypePtr;
std::unique_ptr<document::select::Node> _docSelectionUp;
bool _missingDocumentImpliesMatch;
@@ -50,10 +50,44 @@ class TestAndSetHelper {
spi::GetResult retrieveDocument(const document::FieldSet & fieldSet, spi::Context & context);
public:
- TestAndSetHelper(const PersistenceUtil & env, const spi::PersistenceProvider & _spi,
- const document::BucketIdFactory & bucketIdFactory,
- const api::TestAndSetCommand & cmd, bool missingDocumentImpliesMatch = false);
+ struct Result {
+ enum class ConditionOutcome {
+ DocNotFound,
+ IsMatch,
+ IsNotMatch,
+ IsTombstone
+ };
+
+ api::Timestamp timestamp = 0;
+ ConditionOutcome condition_outcome = ConditionOutcome::IsNotMatch;
+
+ [[nodiscard]] bool doc_not_found() const noexcept {
+ return condition_outcome == ConditionOutcome::DocNotFound;
+ }
+ [[nodiscard]] bool is_match() const noexcept {
+ return condition_outcome == ConditionOutcome::IsMatch;
+ }
+ [[nodiscard]] bool is_not_match() const noexcept {
+ return condition_outcome == ConditionOutcome::IsNotMatch;
+ }
+ [[nodiscard]] bool is_tombstone() const noexcept {
+ return condition_outcome == ConditionOutcome::IsTombstone;
+ }
+ };
+
+ TestAndSetHelper(const PersistenceUtil& env,
+ const spi::PersistenceProvider& _spi,
+ const document::BucketIdFactory& bucket_id_factory,
+ const documentapi::TestAndSetCondition& condition,
+ document::Bucket bucket,
+ document::DocumentId doc_id,
+ const document::DocumentType* doc_type_ptr,
+ bool missingDocumentImpliesMatch = false);
~TestAndSetHelper();
+
+ Result fetch_and_match_raw(spi::Context& context);
+ api::ReturnCode to_api_return_code(const Result& result) const;
+
api::ReturnCode retrieveAndMatch(spi::Context & context);
};
diff --git a/storage/src/vespa/storageapi/message/persistence.cpp b/storage/src/vespa/storageapi/message/persistence.cpp
index 41a53449b67..1b09639fd9b 100644
--- a/storage/src/vespa/storageapi/message/persistence.cpp
+++ b/storage/src/vespa/storageapi/message/persistence.cpp
@@ -222,7 +222,8 @@ GetReply::GetReply(const GetCommand& cmd,
const DocumentSP& doc,
Timestamp lastModified,
bool had_consistent_replicas,
- bool is_tombstone)
+ bool is_tombstone,
+ bool condition_matched)
: BucketInfoReply(cmd),
_docId(cmd.getDocumentId()),
_fieldSet(cmd.getFieldSet()),
@@ -230,7 +231,8 @@ GetReply::GetReply(const GetCommand& cmd,
_beforeTimestamp(cmd.getBeforeTimestamp()),
_lastModifiedTime(lastModified),
_had_consistent_replicas(had_consistent_replicas),
- _is_tombstone(is_tombstone)
+ _is_tombstone(is_tombstone),
+ _condition_matched(condition_matched)
{
}
diff --git a/storage/src/vespa/storageapi/message/persistence.h b/storage/src/vespa/storageapi/message/persistence.h
index d1709c46a6e..d010c295ca7 100644
--- a/storage/src/vespa/storageapi/message/persistence.h
+++ b/storage/src/vespa/storageapi/message/persistence.h
@@ -185,9 +185,10 @@ public:
* timestamp.
*/
class GetCommand : public BucketInfoCommand {
- document::DocumentId _docId;
- Timestamp _beforeTimestamp;
- vespalib::string _fieldSet;
+ document::DocumentId _docId;
+ Timestamp _beforeTimestamp;
+ vespalib::string _fieldSet;
+ TestAndSetCondition _condition;
InternalReadConsistency _internal_read_consistency;
public:
GetCommand(const document::Bucket &bucket, const document::DocumentId&,
@@ -198,6 +199,9 @@ public:
Timestamp getBeforeTimestamp() const { return _beforeTimestamp; }
const vespalib::string& getFieldSet() const { return _fieldSet; }
void setFieldSet(vespalib::stringref fieldSet) { _fieldSet = fieldSet; }
+ [[nodiscard]] bool has_condition() const noexcept { return _condition.isPresent(); }
+ [[nodiscard]] const TestAndSetCondition& condition() const noexcept { return _condition; }
+ void set_condition(TestAndSetCondition cond) { _condition = std::move(cond); }
InternalReadConsistency internal_read_consistency() const noexcept {
return _internal_read_consistency;
}
@@ -229,12 +233,14 @@ class GetReply : public BucketInfoReply {
Timestamp _lastModifiedTime;
bool _had_consistent_replicas;
bool _is_tombstone;
+ bool _condition_matched;
public:
explicit GetReply(const GetCommand& cmd,
const DocumentSP& doc = DocumentSP(),
Timestamp lastModified = 0,
bool had_consistent_replicas = false,
- bool is_tombstone = false);
+ bool is_tombstone = false,
+ bool condition_matched = false);
~GetReply() override;
@@ -247,6 +253,7 @@ public:
[[nodiscard]] bool had_consistent_replicas() const noexcept { return _had_consistent_replicas; }
[[nodiscard]] bool is_tombstone() const noexcept { return _is_tombstone; }
+ [[nodiscard]] bool condition_matched() const noexcept { return _condition_matched; }
bool wasFound() const { return (_doc.get() != nullptr); }
void print(std::ostream& out, bool verbose, const std::string& indent) const override;
diff --git a/vdslib/src/main/java/com/yahoo/vdslib/distribution/Distribution.java b/vdslib/src/main/java/com/yahoo/vdslib/distribution/Distribution.java
index 9a451ac56ec..a83e2a4f89c 100644
--- a/vdslib/src/main/java/com/yahoo/vdslib/distribution/Distribution.java
+++ b/vdslib/src/main/java/com/yahoo/vdslib/distribution/Distribution.java
@@ -187,11 +187,11 @@ public class Distribution {
}
private int getStorageSeed(BucketId bucket, ClusterState state) {
- int seed = (int) lastNBits(bucket.getRawId(), state.getDistributionBitCount());
+ int seed = (int)lastNBits(bucket.getRawId(), state.getDistributionBitCount());
if (bucket.getUsedBits() > 33) {
int usedBits = bucket.getUsedBits() - 1;
- seed ^= lastNBits(bucket.getRawId() >> 32, usedBits - 32) << 6;
+ seed ^= (int)lastNBits(bucket.getRawId() >> 32, usedBits - 32) << 6;
}
return seed;
}
diff --git a/vespa-dependencies-enforcer/allowed-maven-dependencies.txt b/vespa-dependencies-enforcer/allowed-maven-dependencies.txt
index b5841d1c9e4..0d007097fa2 100644
--- a/vespa-dependencies-enforcer/allowed-maven-dependencies.txt
+++ b/vespa-dependencies-enforcer/allowed-maven-dependencies.txt
@@ -34,12 +34,20 @@ com.google.protobuf:protobuf-java:3.21.7
com.ibm.icu:icu4j:70.1
com.intellij:annotations:9.0.4
com.microsoft.onnxruntime:onnxruntime:1.13.1
+com.squareup.okhttp3:okhttp:3.14.9
+com.squareup.okio:okio:1.17.2
+com.squareup.retrofit2:adapter-rxjava2:2.9.0
+com.squareup.retrofit2:converter-jackson:2.9.0
+com.squareup.retrofit2:retrofit:2.9.0
com.sun.activation:javax.activation:1.2.0
com.sun.istack:istack-commons-runtime:3.0.8
com.sun.xml.bind:jaxb-core:2.3.0
com.sun.xml.bind:jaxb-impl:2.3.0
com.sun.xml.fastinfoset:FastInfoset:1.2.16
com.thaiopensource:jing:20091111
+com.theokanning.openai-gpt3-java:api:0.12.0
+com.theokanning.openai-gpt3-java:client:0.12.0
+com.theokanning.openai-gpt3-java:service:0.12.0
com.yahoo.athenz:athenz-auth-core:1.10.54
com.yahoo.athenz:athenz-client-common:1.10.54
com.yahoo.athenz:athenz-zms-core:1.10.54
@@ -69,6 +77,7 @@ io.netty:netty-transport-native-epoll:4.1.86.Final
io.netty:netty-transport-native-unix-common:4.1.86.Final
io.prometheus:simpleclient:0.6.0
io.prometheus:simpleclient_common:0.6.0
+io.reactivex.rxjava2:rxjava:2.0.0
javax.annotation:javax.annotation-api:1.2
javax.inject:javax.inject:1
javax.servlet:javax.servlet-api:3.1.0
@@ -201,6 +210,7 @@ org.ow2.asm:asm-commons:9.3
org.ow2.asm:asm-tree:9.3
org.ow2.asm:asm-util:9.3
org.questdb:questdb:6.2
+org.reactivestreams:reactive-streams:1.0.3
org.slf4j:jcl-over-slf4j:1.7.32
org.slf4j:log4j-over-slf4j:1.7.32
org.slf4j:slf4j-api:1.7.32
diff --git a/vespa-feed-client-api/pom.xml b/vespa-feed-client-api/pom.xml
index 5509c339eee..0782bb6b28e 100644
--- a/vespa-feed-client-api/pom.xml
+++ b/vespa-feed-client-api/pom.xml
@@ -42,6 +42,12 @@
<configuration>
<release>${vespaClients.jdk.releaseVersion}</release>
<showDeprecation>true</showDeprecation>
+ <compilerArgs> <!-- Remove (to use default) when not compiling for 8 -->
+ <arg>-Xlint:all</arg>
+ <arg>-Xlint:-rawtypes</arg>
+ <arg>-Xlint:-unchecked</arg>
+ <arg>-Xlint:-serial</arg>
+ </compilerArgs>
</configuration>
</plugin>
<plugin>
diff --git a/vespa-feed-client-cli/pom.xml b/vespa-feed-client-cli/pom.xml
index 46679906fc4..b917a39b675 100644
--- a/vespa-feed-client-cli/pom.xml
+++ b/vespa-feed-client-cli/pom.xml
@@ -53,6 +53,12 @@
<configuration>
<release>${vespaClients.jdk.releaseVersion}</release>
<showDeprecation>true</showDeprecation>
+ <compilerArgs> <!-- Remove (to use default) when not compiling for 8 -->
+ <arg>-Xlint:all</arg>
+ <arg>-Xlint:-rawtypes</arg>
+ <arg>-Xlint:-unchecked</arg>
+ <arg>-Xlint:-serial</arg>
+ </compilerArgs>
</configuration>
</plugin>
<plugin>
diff --git a/vespa-feed-client/pom.xml b/vespa-feed-client/pom.xml
index 01b9b00b8a0..b6440653a78 100644
--- a/vespa-feed-client/pom.xml
+++ b/vespa-feed-client/pom.xml
@@ -65,6 +65,12 @@
<goal>compile</goal>
</goals>
<configuration>
+ <compilerArgs> <!-- Remove (to use default) when not compiling for 8 -->
+ <arg>-Xlint:all</arg>
+ <arg>-Xlint:-rawtypes</arg>
+ <arg>-Xlint:-unchecked</arg>
+ <arg>-Xlint:-serial</arg>
+ </compilerArgs>
<release>${vespaClients.jdk.releaseVersion}</release>
<showDeprecation>true</showDeprecation>
</configuration>
diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java b/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java
index 2a688ad078b..5126f7e0f43 100644
--- a/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java
+++ b/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java
@@ -526,23 +526,18 @@ public class DocumentV1ApiHandler extends AbstractRequestHandler {
parameters = getProperty(request, TIMEOUT, timeoutMillisParser).map(clock.instant()::plusMillis)
.map(parameters::withDeadline)
.orElse(parameters);
- for (String name : names) switch (name) {
- case CLUSTER:
- parameters = getProperty(request, CLUSTER).map(cluster -> resolveCluster(Optional.of(cluster), clusters).name())
- .map(parameters::withRoute)
- .orElse(parameters);
- break;
- case FIELD_SET:
- parameters = getProperty(request, FIELD_SET).map(parameters::withFieldSet)
- .orElse(parameters);
- break;
- case ROUTE:
- parameters = getProperty(request, ROUTE).map(parameters::withRoute)
- .orElse(parameters);
- break;
- default:
- throw new IllegalArgumentException("Unrecognized document operation parameter name '" + name + "'");
- }
+ for (String name : names)
+ parameters = switch (name) {
+ case CLUSTER ->
+ getProperty(request, CLUSTER)
+ .map(cluster -> resolveCluster(Optional.of(cluster), clusters).name())
+ .map(parameters::withRoute)
+ .orElse(parameters);
+ case FIELD_SET -> getProperty(request, FIELD_SET).map(parameters::withFieldSet).orElse(parameters);
+ case ROUTE -> getProperty(request, ROUTE).map(parameters::withRoute).orElse(parameters);
+ default ->
+ throw new IllegalArgumentException("Unrecognized document operation parameter name '" + name + "'");
+ };
return parameters;
}
@@ -630,10 +625,6 @@ public class DocumentV1ApiHandler extends AbstractRequestHandler {
private boolean first = true;
private ContentChannel channel;
- private JsonResponse(ResponseHandler handler) throws IOException {
- this(handler, null);
- }
-
private JsonResponse(ResponseHandler handler, HttpRequest request) throws IOException {
this.handler = handler;
this.request = request;
@@ -642,11 +633,6 @@ public class DocumentV1ApiHandler extends AbstractRequestHandler {
}
/** Creates a new JsonResponse with path and id fields written. */
- static JsonResponse create(DocumentPath path, ResponseHandler handler) throws IOException {
- return create(path, handler, null);
- }
-
- /** Creates a new JsonResponse with path and id fields written. */
static JsonResponse create(DocumentPath path, ResponseHandler handler, HttpRequest request) throws IOException {
JsonResponse response = new JsonResponse(handler, request);
response.writePathId(path.rawPath());
@@ -749,23 +735,17 @@ public class DocumentV1ApiHandler extends AbstractRequestHandler {
}
private boolean tensorShortForm() {
- if (request != null &&
- request.parameters().containsKey("format.tensors") &&
- ( request.parameters().get("format.tensors").contains("long")
- || request.parameters().get("format.tensors").contains("long-value"))) {
- return false;
- }
- return true; // default
+ return request == null ||
+ !request.parameters().containsKey("format.tensors") ||
+ (!request.parameters().get("format.tensors").contains("long")
+ && !request.parameters().get("format.tensors").contains("long-value"));// default
}
private boolean tensorDirectValues() {
- if (request != null &&
- request.parameters().containsKey("format.tensors") &&
- ( request.parameters().get("format.tensors").contains("short-value")
- || request.parameters().get("format.tensors").contains("long-value"))) {
- return true;
- }
- return false; // TODO: Flip default on Vespa 9
+ return request != null &&
+ request.parameters().containsKey("format.tensors") &&
+ (request.parameters().get("format.tensors").contains("short-value")
+ || request.parameters().get("format.tensors").contains("long-value"));// TODO: Flip default on Vespa 9
}
synchronized void writeSingleDocument(Document document) throws IOException {
@@ -1168,9 +1148,8 @@ public class DocumentV1ApiHandler extends AbstractRequestHandler {
// ------------------------------------------------- Visits ------------------------------------------------
private VisitorParameters parseGetParameters(HttpRequest request, DocumentPath path, boolean streamed) {
- int wantedDocumentCount = Math.min(streamed ? Integer.MAX_VALUE : 1 << 10,
- getProperty(request, WANTED_DOCUMENT_COUNT, integerParser)
- .orElse(streamed ? Integer.MAX_VALUE : 1));
+ int wantedDocumentCount = getProperty(request, WANTED_DOCUMENT_COUNT, integerParser)
+ .orElse(streamed ? Integer.MAX_VALUE : 1);
if (wantedDocumentCount <= 0)
throw new IllegalArgumentException("wantedDocumentCount must be positive");
@@ -1546,11 +1525,11 @@ public class DocumentV1ApiHandler extends AbstractRequestHandler {
private static Map<String, StorageCluster> parseClusters(ClusterListConfig clusters, AllClustersBucketSpacesConfig buckets) {
return clusters.storage().stream()
- .collect(toUnmodifiableMap(storage -> storage.name(),
+ .collect(toUnmodifiableMap(ClusterListConfig.Storage::name,
storage -> new StorageCluster(storage.name(),
buckets.cluster(storage.name())
.documentType().entrySet().stream()
- .collect(toMap(entry -> entry.getKey(),
+ .collect(toMap(Map.Entry::getKey,
entry -> entry.getValue().bucketSpace())))));
}
diff --git a/vespaclient-container-plugin/src/test/java/com/yahoo/document/restapi/resource/DocumentV1ApiTest.java b/vespaclient-container-plugin/src/test/java/com/yahoo/document/restapi/resource/DocumentV1ApiTest.java
index 851a0949266..7696fd2196c 100644
--- a/vespaclient-container-plugin/src/test/java/com/yahoo/document/restapi/resource/DocumentV1ApiTest.java
+++ b/vespaclient-container-plugin/src/test/java/com/yahoo/document/restapi/resource/DocumentV1ApiTest.java
@@ -217,7 +217,7 @@ public class DocumentV1ApiTest {
access.expect(parameters -> {
assertEquals("content", parameters.getRoute().toString());
assertEquals("default", parameters.getBucketSpace());
- assertEquals(1024, parameters.getMaxTotalHits());
+ assertEquals(1025, parameters.getMaxTotalHits());
assertEquals(100, ((StaticThrottlePolicy) parameters.getThrottlePolicy()).getMaxPendingCount());
assertEquals("[id]", parameters.getFieldSet());
assertEquals("(all the things)", parameters.getDocumentSelection());
diff --git a/vespajlib/src/main/java/com/yahoo/slime/BinaryEncoder.java b/vespajlib/src/main/java/com/yahoo/slime/BinaryEncoder.java
index f12496f7a76..e0b4fb2c672 100644
--- a/vespajlib/src/main/java/com/yahoo/slime/BinaryEncoder.java
+++ b/vespajlib/src/main/java/com/yahoo/slime/BinaryEncoder.java
@@ -28,7 +28,7 @@ final class BinaryEncoder implements ArrayTraverser, ObjectSymbolTraverser {
byte next = (byte)(value & 0x7f);
value >>>= 7; // unsigned shift
while (value != 0) {
- next |= 0x80;
+ next |= (byte)0x80;
out.put(next);
next = (byte)(value & 0x7f);
value >>>= 7;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index 11996b6a23d..de1c30e6414 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -333,7 +333,7 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N
private final long[] bounds;
private final long[] iterator;
- private int remaining;
+ private long remaining;
MultiDimensionIterator(TensorType type) {
bounds = new long[type.dimensions().size()];
diff --git a/vespajlib/src/test/java/com/yahoo/io/FatalErrorHandlerTestCase.java b/vespajlib/src/test/java/com/yahoo/io/FatalErrorHandlerTestCase.java
deleted file mode 100644
index dab91b6a995..00000000000
--- a/vespajlib/src/test/java/com/yahoo/io/FatalErrorHandlerTestCase.java
+++ /dev/null
@@ -1,58 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.io;
-
-import static org.junit.Assert.*;
-
-import java.security.Permission;
-
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-
-/**
- * Just to remove noise from the coverage report.
- *
- * @author Steinar Knutsen
- */
-public class FatalErrorHandlerTestCase {
- @SuppressWarnings("removal")
- private static final class AvoidExiting extends SecurityManager {
-
- @Override
- public void checkPermission(Permission perm) {
- }
-
- @Override
- public void checkExit(int status) {
- throw new SecurityException();
- }
-
- }
-
- private FatalErrorHandler h;
-
- @Before
- @SuppressWarnings("removal")
- public void setUp() throws Exception {
- h = new FatalErrorHandler();
- System.setSecurityManager(new AvoidExiting());
- }
-
- @After
- @SuppressWarnings("removal")
- public void tearDown() throws Exception {
- System.setSecurityManager(null);
- }
-
- @Test
- public final void testHandle() {
- boolean caught = false;
- try {
- h.handle(new Throwable(), "abc");
- } catch (SecurityException e) {
- caught = true;
- }
- assertTrue(caught);
- }
-
-}