diff options
40 files changed, 938 insertions, 686 deletions
diff --git a/client/go/vespa/target.go b/client/go/vespa/target.go index 2314cc06f33..8d48f1ffa3a 100644 --- a/client/go/vespa/target.go +++ b/client/go/vespa/target.go @@ -3,23 +3,15 @@ package vespa import ( - "bytes" "crypto/tls" - "encoding/json" "fmt" "io" "io/ioutil" - "math" "net/http" - "net/url" - "sort" - "strconv" "time" - "github.com/vespa-engine/vespa/client/go/auth0" "github.com/vespa-engine/vespa/client/go/util" "github.com/vespa-engine/vespa/client/go/version" - "github.com/vespa-engine/vespa/client/go/zts" ) const ( @@ -94,26 +86,6 @@ type LogOptions struct { Level int } -// CloudOptions configures URL and authentication for a cloud target. -type APIOptions struct { - System System - TLSOptions TLSOptions - APIKey []byte - AuthConfigPath string -} - -// CloudDeploymentOptions configures the deployment to manage through a cloud target. -type CloudDeploymentOptions struct { - Deployment Deployment - TLSOptions TLSOptions - ClusterURLs map[string]string // Endpoints keyed on cluster name -} - -type customTarget struct { - targetType string - baseURL string -} - // Do sends request to this service. Any required authentication happens automatically. 
func (s *Service) Do(request *http.Request, timeout time.Duration) (*http.Response, error) { if s.TLSOptions.KeyPair.Certificate != nil { @@ -143,12 +115,7 @@ func (s *Service) Wait(timeout time.Duration) (int, error) { default: return 0, fmt.Errorf("invalid service: %s", s.Name) } - req, err := http.NewRequest("GET", url, nil) - if err != nil { - return 0, err - } - okFunc := func(status int, response []byte) (bool, error) { return status/100 == 2, nil } - return wait(okFunc, func() *http.Request { return req }, &s.TLSOptions.KeyPair, timeout) + return waitForOK(url, &s.TLSOptions.KeyPair, timeout) } func (s *Service) Description() string { @@ -163,442 +130,23 @@ func (s *Service) Description() string { return fmt.Sprintf("No description of service %s", s.Name) } -func (t *customTarget) Type() string { return t.targetType } - -func (t *customTarget) Deployment() Deployment { return Deployment{} } - -func (t *customTarget) Service(name string, timeout time.Duration, sessionOrRunID int64, cluster string) (*Service, error) { - if timeout > 0 && name != DeployService { - if err := t.waitForConvergence(timeout); err != nil { - return nil, err - } - } - switch name { - case DeployService, QueryService, DocumentService: - url, err := t.urlWithPort(name) - if err != nil { - return nil, err - } - return &Service{BaseURL: url, Name: name}, nil - } - return nil, fmt.Errorf("unknown service: %s", name) -} - -func (t *customTarget) PrintLog(options LogOptions) error { - return fmt.Errorf("reading logs from non-cloud deployment is unsupported") -} - -func (t *customTarget) SignRequest(req *http.Request, sigKeyId string) error { return nil } +func isOK(status int) bool { return status/100 == 2 } -func (t *customTarget) CheckVersion(version version.Version) error { return nil } +type responseFunc func(status int, response []byte) (bool, error) -func (t *customTarget) urlWithPort(serviceName string) (string, error) { - u, err := url.Parse(t.baseURL) - if err != nil { - return "", 
err - } - port := u.Port() - if port == "" { - switch serviceName { - case DeployService: - port = "19071" - case QueryService, DocumentService: - port = "8080" - default: - return "", fmt.Errorf("unknown service: %s", serviceName) - } - u.Host = u.Host + ":" + port - } - return u.String(), nil -} +type requestFunc func() *http.Request -func (t *customTarget) waitForConvergence(timeout time.Duration) error { - deployer, err := t.Service(DeployService, 0, 0, "") - if err != nil { - return err - } - url := fmt.Sprintf("%s/application/v2/tenant/default/application/default/environment/prod/region/default/instance/default/serviceconverge", deployer.BaseURL) +// waitForOK queries url and returns its status code. If the url returns a non-200 status code, it is repeatedly queried +// until timeout elapses. +func waitForOK(url string, certificate *tls.Certificate, timeout time.Duration) (int, error) { req, err := http.NewRequest("GET", url, nil) if err != nil { - return err - } - converged := false - convergedFunc := func(status int, response []byte) (bool, error) { - if status/100 != 2 { - return false, nil - } - var resp serviceConvergeResponse - if err := json.Unmarshal(response, &resp); err != nil { - return false, nil - } - converged = resp.Converged - return converged, nil - } - if _, err := wait(convergedFunc, func() *http.Request { return req }, nil, timeout); err != nil { - return err - } - if !converged { - return fmt.Errorf("services have not converged") - } - return nil -} - -type cloudTarget struct { - apiOptions APIOptions - deploymentOptions CloudDeploymentOptions - logOptions LogOptions - ztsClient ztsClient -} - -type ztsClient interface { - AccessToken(domain string, certficiate tls.Certificate) (string, error) -} - -func (t *cloudTarget) resolveEndpoint(cluster string) (string, error) { - if cluster == "" { - for _, u := range t.deploymentOptions.ClusterURLs { - if len(t.deploymentOptions.ClusterURLs) == 1 { - return u, nil - } else { - return "", 
fmt.Errorf("multiple clusters, none chosen: %v", t.deploymentOptions.ClusterURLs) - } - } - } else { - u := t.deploymentOptions.ClusterURLs[cluster] - if u == "" { - clusters := make([]string, len(t.deploymentOptions.ClusterURLs)) - for c := range t.deploymentOptions.ClusterURLs { - clusters = append(clusters, c) - } - return "", fmt.Errorf("unknown cluster '%s': must be one of %v", cluster, clusters) - } - return u, nil - } - - return "", fmt.Errorf("no endpoints") -} - -func (t *cloudTarget) Type() string { - switch t.apiOptions.System.Name { - case MainSystem.Name, CDSystem.Name: - return TargetHosted - } - return TargetCloud -} - -func (t *cloudTarget) Deployment() Deployment { return t.deploymentOptions.Deployment } - -func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, cluster string) (*Service, error) { - if name != DeployService && t.deploymentOptions.ClusterURLs == nil { - if err := t.waitForEndpoints(timeout, runID); err != nil { - return nil, err - } - } - switch name { - case DeployService: - return &Service{Name: name, BaseURL: t.apiOptions.System.URL, TLSOptions: t.apiOptions.TLSOptions, ztsClient: t.ztsClient}, nil - case QueryService, DocumentService: - url, err := t.resolveEndpoint(cluster) - if err != nil { - return nil, err - } - t.deploymentOptions.TLSOptions.AthenzDomain = t.apiOptions.System.AthenzDomain - return &Service{Name: name, BaseURL: url, TLSOptions: t.deploymentOptions.TLSOptions, ztsClient: t.ztsClient}, nil - } - return nil, fmt.Errorf("unknown service: %s", name) -} - -func (t *cloudTarget) SignRequest(req *http.Request, keyID string) error { - if t.apiOptions.System.IsPublic() { - if t.apiOptions.APIKey != nil { - signer := NewRequestSigner(keyID, t.apiOptions.APIKey) - return signer.SignRequest(req) - } else { - return t.addAuth0AccessToken(req) - } - } else { - if t.apiOptions.TLSOptions.KeyPair.Certificate == nil { - return fmt.Errorf("system %s requires a certificate for authentication", 
t.apiOptions.System.Name) - } - return nil - } -} - -func (t *cloudTarget) CheckVersion(clientVersion version.Version) error { - if clientVersion.IsZero() { // development version is always fine - return nil - } - req, err := http.NewRequest("GET", fmt.Sprintf("%s/cli/v1/", t.apiOptions.System.URL), nil) - if err != nil { - return err - } - response, err := util.HttpDo(req, 10*time.Second, "") - if err != nil { - return err - } - defer response.Body.Close() - var cliResponse struct { - MinVersion string `json:"minVersion"` - } - dec := json.NewDecoder(response.Body) - if err := dec.Decode(&cliResponse); err != nil { - return err - } - minVersion, err := version.Parse(cliResponse.MinVersion) - if err != nil { - return err - } - if clientVersion.Less(minVersion) { - return fmt.Errorf("client version %s is less than the minimum supported version: %s", clientVersion, minVersion) - } - return nil -} - -func (t *cloudTarget) addAuth0AccessToken(request *http.Request) error { - a, err := auth0.GetAuth0(t.apiOptions.AuthConfigPath, t.apiOptions.System.Name, t.apiOptions.System.URL) - if err != nil { - return err - } - system, err := a.PrepareSystem(auth0.ContextWithCancel()) - if err != nil { - return err - } - request.Header.Set("Authorization", "Bearer "+system.AccessToken) - return nil -} - -func (t *cloudTarget) logsURL() string { - return fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s/logs", - t.apiOptions.System.URL, - t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance, - t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region) -} - -func (t *cloudTarget) PrintLog(options LogOptions) error { - req, err := http.NewRequest("GET", t.logsURL(), nil) - if err != nil { - return err - } - lastFrom := options.From - requestFunc := func() *http.Request { - fromMillis := lastFrom.Unix() * 
1000 - q := req.URL.Query() - q.Set("from", strconv.FormatInt(fromMillis, 10)) - if !options.To.IsZero() { - toMillis := options.To.Unix() * 1000 - q.Set("to", strconv.FormatInt(toMillis, 10)) - } - req.URL.RawQuery = q.Encode() - t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()) - return req - } - logFunc := func(status int, response []byte) (bool, error) { - if ok, err := isOK(status); !ok { - return ok, err - } - logEntries, err := ReadLogEntries(bytes.NewReader(response)) - if err != nil { - return true, err - } - for _, le := range logEntries { - if !le.Time.After(lastFrom) { - continue - } - if LogLevel(le.Level) > options.Level { - continue - } - fmt.Fprintln(options.Writer, le.Format(options.Dequote)) - } - if len(logEntries) > 0 { - lastFrom = logEntries[len(logEntries)-1].Time - } - return false, nil - } - var timeout time.Duration - if options.Follow { - timeout = math.MaxInt64 // No timeout - } - _, err = wait(logFunc, requestFunc, &t.apiOptions.TLSOptions.KeyPair, timeout) - return err -} - -func (t *cloudTarget) waitForEndpoints(timeout time.Duration, runID int64) error { - if runID > 0 { - if err := t.waitForRun(runID, timeout); err != nil { - return err - } - } - return t.discoverEndpoints(timeout) -} - -func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error { - runURL := fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/job/%s-%s/run/%d", - t.apiOptions.System.URL, - t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance, - t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region, runID) - req, err := http.NewRequest("GET", runURL, nil) - if err != nil { - return err - } - lastID := int64(-1) - requestFunc := func() *http.Request { - q := req.URL.Query() - q.Set("after", strconv.FormatInt(lastID, 10)) - req.URL.RawQuery = q.Encode() - if err := 
t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil { - panic(err) - } - return req - } - jobSuccessFunc := func(status int, response []byte) (bool, error) { - if ok, err := isOK(status); !ok { - return ok, err - } - var resp jobResponse - if err := json.Unmarshal(response, &resp); err != nil { - return false, nil - } - if t.logOptions.Writer != nil { - lastID = t.printLog(resp, lastID) - } - if resp.Active { - return false, nil - } - if resp.Status != "success" { - return false, fmt.Errorf("run %d ended with unsuccessful status: %s", runID, resp.Status) - } - return true, nil - } - _, err = wait(jobSuccessFunc, requestFunc, &t.apiOptions.TLSOptions.KeyPair, timeout) - return err -} - -func (t *cloudTarget) printLog(response jobResponse, last int64) int64 { - if response.LastID == 0 { - return last - } - var msgs []logMessage - for step, stepMsgs := range response.Log { - for _, msg := range stepMsgs { - if step == "copyVespaLogs" && LogLevel(msg.Type) > t.logOptions.Level || LogLevel(msg.Type) == 3 { - continue - } - msgs = append(msgs, msg) - } - } - sort.Slice(msgs, func(i, j int) bool { return msgs[i].At < msgs[j].At }) - for _, msg := range msgs { - tm := time.Unix(msg.At/1000, (msg.At%1000)*1000) - fmtTime := tm.Format("15:04:05") - fmt.Fprintf(t.logOptions.Writer, "[%s] %-7s %s\n", fmtTime, msg.Type, msg.Message) - } - return response.LastID -} - -func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error { - deploymentURL := fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s", - t.apiOptions.System.URL, - t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance, - t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region) - req, err := http.NewRequest("GET", deploymentURL, nil) - if err != nil { - return err - } - if err := 
t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil { - return err - } - urlsByCluster := make(map[string]string) - endpointFunc := func(status int, response []byte) (bool, error) { - if ok, err := isOK(status); !ok { - return ok, err - } - var resp deploymentResponse - if err := json.Unmarshal(response, &resp); err != nil { - return false, nil - } - if len(resp.Endpoints) == 0 { - return false, nil - } - for _, endpoint := range resp.Endpoints { - if endpoint.Scope != "zone" { - continue - } - urlsByCluster[endpoint.Cluster] = endpoint.URL - } - return true, nil - } - if _, err = wait(endpointFunc, func() *http.Request { return req }, &t.apiOptions.TLSOptions.KeyPair, timeout); err != nil { - return err - } - if len(urlsByCluster) == 0 { - return fmt.Errorf("no endpoints discovered") - } - t.deploymentOptions.ClusterURLs = urlsByCluster - return nil -} - -func isOK(status int) (bool, error) { - if status == 401 { - return false, fmt.Errorf("status %d: invalid api key", status) - } - return status/100 == 2, nil -} - -// LocalTarget creates a target for a Vespa platform running locally. -func LocalTarget() Target { - return &customTarget{targetType: TargetLocal, baseURL: "http://127.0.0.1"} -} - -// CustomTarget creates a Target for a Vespa platform running at baseURL. -func CustomTarget(baseURL string) Target { - return &customTarget{targetType: TargetCustom, baseURL: baseURL} -} - -// CloudTarget creates a Target for the Vespa Cloud or hosted Vespa platform. 
-func CloudTarget(apiOptions APIOptions, deploymentOptions CloudDeploymentOptions, logOptions LogOptions) (Target, error) { - ztsClient, err := zts.NewClient(zts.DefaultURL, util.ActiveHttpClient) - if err != nil { - return nil, err + return 0, err } - return &cloudTarget{ - apiOptions: apiOptions, - deploymentOptions: deploymentOptions, - logOptions: logOptions, - ztsClient: ztsClient, - }, nil -} - -type deploymentEndpoint struct { - Cluster string `json:"cluster"` - URL string `json:"url"` - Scope string `json:"scope"` -} - -type deploymentResponse struct { - Endpoints []deploymentEndpoint `json:"endpoints"` + okFunc := func(status int, response []byte) (bool, error) { return isOK(status), nil } + return wait(okFunc, func() *http.Request { return req }, certificate, timeout) } -type serviceConvergeResponse struct { - Converged bool `json:"converged"` -} - -type jobResponse struct { - Active bool `json:"active"` - Status string `json:"status"` - Log map[string][]logMessage `json:"log"` - LastID int64 `json:"lastId"` -} - -type logMessage struct { - At int64 `json:"at"` - Type string `json:"type"` - Message string `json:"message"` -} - -type responseFunc func(status int, response []byte) (bool, error) - -type requestFunc func() *http.Request - func wait(fn responseFunc, reqFn requestFunc, certificate *tls.Certificate, timeout time.Duration) (int, error) { if certificate != nil { util.ActiveHttpClient.UseCertificate([]tls.Certificate{*certificate}) diff --git a/client/go/vespa/target_cloud.go b/client/go/vespa/target_cloud.go new file mode 100644 index 00000000000..f4eccaacab6 --- /dev/null +++ b/client/go/vespa/target_cloud.go @@ -0,0 +1,382 @@ +package vespa + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "math" + "net/http" + "sort" + "strconv" + "time" + + "github.com/vespa-engine/vespa/client/go/auth0" + "github.com/vespa-engine/vespa/client/go/util" + "github.com/vespa-engine/vespa/client/go/version" + 
"github.com/vespa-engine/vespa/client/go/zts" +) + +// CloudOptions configures URL and authentication for a cloud target. +type APIOptions struct { + System System + TLSOptions TLSOptions + APIKey []byte + AuthConfigPath string +} + +// CloudDeploymentOptions configures the deployment to manage through a cloud target. +type CloudDeploymentOptions struct { + Deployment Deployment + TLSOptions TLSOptions + ClusterURLs map[string]string // Endpoints keyed on cluster name +} + +type cloudTarget struct { + apiOptions APIOptions + deploymentOptions CloudDeploymentOptions + logOptions LogOptions + ztsClient ztsClient +} + +type deploymentEndpoint struct { + Cluster string `json:"cluster"` + URL string `json:"url"` + Scope string `json:"scope"` +} + +type deploymentResponse struct { + Endpoints []deploymentEndpoint `json:"endpoints"` +} + +type jobResponse struct { + Active bool `json:"active"` + Status string `json:"status"` + Log map[string][]logMessage `json:"log"` + LastID int64 `json:"lastId"` +} + +type logMessage struct { + At int64 `json:"at"` + Type string `json:"type"` + Message string `json:"message"` +} + +type ztsClient interface { + AccessToken(domain string, certficiate tls.Certificate) (string, error) +} + +// CloudTarget creates a Target for the Vespa Cloud or hosted Vespa platform. 
+func CloudTarget(apiOptions APIOptions, deploymentOptions CloudDeploymentOptions, logOptions LogOptions) (Target, error) { + ztsClient, err := zts.NewClient(zts.DefaultURL, util.ActiveHttpClient) + if err != nil { + return nil, err + } + return &cloudTarget{ + apiOptions: apiOptions, + deploymentOptions: deploymentOptions, + logOptions: logOptions, + ztsClient: ztsClient, + }, nil +} + +func (t *cloudTarget) resolveEndpoint(cluster string) (string, error) { + if cluster == "" { + for _, u := range t.deploymentOptions.ClusterURLs { + if len(t.deploymentOptions.ClusterURLs) == 1 { + return u, nil + } else { + return "", fmt.Errorf("multiple clusters, none chosen: %v", t.deploymentOptions.ClusterURLs) + } + } + } else { + u := t.deploymentOptions.ClusterURLs[cluster] + if u == "" { + clusters := make([]string, len(t.deploymentOptions.ClusterURLs)) + for c := range t.deploymentOptions.ClusterURLs { + clusters = append(clusters, c) + } + return "", fmt.Errorf("unknown cluster '%s': must be one of %v", cluster, clusters) + } + return u, nil + } + + return "", fmt.Errorf("no endpoints") +} + +func (t *cloudTarget) Type() string { + switch t.apiOptions.System.Name { + case MainSystem.Name, CDSystem.Name: + return TargetHosted + } + return TargetCloud +} + +func (t *cloudTarget) Deployment() Deployment { return t.deploymentOptions.Deployment } + +func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, cluster string) (*Service, error) { + switch name { + case DeployService: + service := &Service{Name: name, BaseURL: t.apiOptions.System.URL, TLSOptions: t.apiOptions.TLSOptions, ztsClient: t.ztsClient} + if timeout > 0 { + status, err := service.Wait(timeout) + if err != nil { + return nil, err + } + if !isOK(status) { + return nil, fmt.Errorf("got status %d from deploy service at %s", status, service.BaseURL) + } + } + return service, nil + case QueryService, DocumentService: + if t.deploymentOptions.ClusterURLs == nil { + if err := 
t.waitForEndpoints(timeout, runID); err != nil { + return nil, err + } + } + url, err := t.resolveEndpoint(cluster) + if err != nil { + return nil, err + } + t.deploymentOptions.TLSOptions.AthenzDomain = t.apiOptions.System.AthenzDomain + return &Service{Name: name, BaseURL: url, TLSOptions: t.deploymentOptions.TLSOptions, ztsClient: t.ztsClient}, nil + } + return nil, fmt.Errorf("unknown service: %s", name) +} + +func (t *cloudTarget) SignRequest(req *http.Request, keyID string) error { + if t.apiOptions.System.IsPublic() { + if t.apiOptions.APIKey != nil { + signer := NewRequestSigner(keyID, t.apiOptions.APIKey) + return signer.SignRequest(req) + } else { + return t.addAuth0AccessToken(req) + } + } else { + if t.apiOptions.TLSOptions.KeyPair.Certificate == nil { + return fmt.Errorf("system %s requires a certificate for authentication", t.apiOptions.System.Name) + } + return nil + } +} + +func (t *cloudTarget) CheckVersion(clientVersion version.Version) error { + if clientVersion.IsZero() { // development version is always fine + return nil + } + req, err := http.NewRequest("GET", fmt.Sprintf("%s/cli/v1/", t.apiOptions.System.URL), nil) + if err != nil { + return err + } + response, err := util.HttpDo(req, 10*time.Second, "") + if err != nil { + return err + } + defer response.Body.Close() + var cliResponse struct { + MinVersion string `json:"minVersion"` + } + dec := json.NewDecoder(response.Body) + if err := dec.Decode(&cliResponse); err != nil { + return err + } + minVersion, err := version.Parse(cliResponse.MinVersion) + if err != nil { + return err + } + if clientVersion.Less(minVersion) { + return fmt.Errorf("client version %s is less than the minimum supported version: %s", clientVersion, minVersion) + } + return nil +} + +func (t *cloudTarget) addAuth0AccessToken(request *http.Request) error { + a, err := auth0.GetAuth0(t.apiOptions.AuthConfigPath, t.apiOptions.System.Name, t.apiOptions.System.URL) + if err != nil { + return err + } + system, err := 
a.PrepareSystem(auth0.ContextWithCancel()) + if err != nil { + return err + } + request.Header.Set("Authorization", "Bearer "+system.AccessToken) + return nil +} + +func (t *cloudTarget) logsURL() string { + return fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s/logs", + t.apiOptions.System.URL, + t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance, + t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region) +} + +func (t *cloudTarget) PrintLog(options LogOptions) error { + req, err := http.NewRequest("GET", t.logsURL(), nil) + if err != nil { + return err + } + lastFrom := options.From + requestFunc := func() *http.Request { + fromMillis := lastFrom.Unix() * 1000 + q := req.URL.Query() + q.Set("from", strconv.FormatInt(fromMillis, 10)) + if !options.To.IsZero() { + toMillis := options.To.Unix() * 1000 + q.Set("to", strconv.FormatInt(toMillis, 10)) + } + req.URL.RawQuery = q.Encode() + t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()) + return req + } + logFunc := func(status int, response []byte) (bool, error) { + if ok, err := isCloudOK(status); !ok { + return ok, err + } + logEntries, err := ReadLogEntries(bytes.NewReader(response)) + if err != nil { + return true, err + } + for _, le := range logEntries { + if !le.Time.After(lastFrom) { + continue + } + if LogLevel(le.Level) > options.Level { + continue + } + fmt.Fprintln(options.Writer, le.Format(options.Dequote)) + } + if len(logEntries) > 0 { + lastFrom = logEntries[len(logEntries)-1].Time + } + return false, nil + } + var timeout time.Duration + if options.Follow { + timeout = math.MaxInt64 // No timeout + } + _, err = wait(logFunc, requestFunc, &t.apiOptions.TLSOptions.KeyPair, timeout) + return err +} + +func (t *cloudTarget) waitForEndpoints(timeout time.Duration, runID int64) error { + if 
runID > 0 { + if err := t.waitForRun(runID, timeout); err != nil { + return err + } + } + return t.discoverEndpoints(timeout) +} + +func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error { + runURL := fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/job/%s-%s/run/%d", + t.apiOptions.System.URL, + t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance, + t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region, runID) + req, err := http.NewRequest("GET", runURL, nil) + if err != nil { + return err + } + lastID := int64(-1) + requestFunc := func() *http.Request { + q := req.URL.Query() + q.Set("after", strconv.FormatInt(lastID, 10)) + req.URL.RawQuery = q.Encode() + if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil { + panic(err) + } + return req + } + jobSuccessFunc := func(status int, response []byte) (bool, error) { + if ok, err := isCloudOK(status); !ok { + return ok, err + } + var resp jobResponse + if err := json.Unmarshal(response, &resp); err != nil { + return false, nil + } + if t.logOptions.Writer != nil { + lastID = t.printLog(resp, lastID) + } + if resp.Active { + return false, nil + } + if resp.Status != "success" { + return false, fmt.Errorf("run %d ended with unsuccessful status: %s", runID, resp.Status) + } + return true, nil + } + _, err = wait(jobSuccessFunc, requestFunc, &t.apiOptions.TLSOptions.KeyPair, timeout) + return err +} + +func (t *cloudTarget) printLog(response jobResponse, last int64) int64 { + if response.LastID == 0 { + return last + } + var msgs []logMessage + for step, stepMsgs := range response.Log { + for _, msg := range stepMsgs { + if step == "copyVespaLogs" && LogLevel(msg.Type) > t.logOptions.Level || LogLevel(msg.Type) == 3 { + continue + } + msgs = append(msgs, msg) + } + } + sort.Slice(msgs, func(i, 
j int) bool { return msgs[i].At < msgs[j].At }) + for _, msg := range msgs { + tm := time.Unix(msg.At/1000, (msg.At%1000)*1000) + fmtTime := tm.Format("15:04:05") + fmt.Fprintf(t.logOptions.Writer, "[%s] %-7s %s\n", fmtTime, msg.Type, msg.Message) + } + return response.LastID +} + +func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error { + deploymentURL := fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s", + t.apiOptions.System.URL, + t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance, + t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region) + req, err := http.NewRequest("GET", deploymentURL, nil) + if err != nil { + return err + } + if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil { + return err + } + urlsByCluster := make(map[string]string) + endpointFunc := func(status int, response []byte) (bool, error) { + if ok, err := isCloudOK(status); !ok { + return ok, err + } + var resp deploymentResponse + if err := json.Unmarshal(response, &resp); err != nil { + return false, nil + } + if len(resp.Endpoints) == 0 { + return false, nil + } + for _, endpoint := range resp.Endpoints { + if endpoint.Scope != "zone" { + continue + } + urlsByCluster[endpoint.Cluster] = endpoint.URL + } + return true, nil + } + if _, err = wait(endpointFunc, func() *http.Request { return req }, &t.apiOptions.TLSOptions.KeyPair, timeout); err != nil { + return err + } + if len(urlsByCluster) == 0 { + return fmt.Errorf("no endpoints discovered") + } + t.deploymentOptions.ClusterURLs = urlsByCluster + return nil +} + +func isCloudOK(status int) (bool, error) { + if status == 401 { + // when retrying we should give up immediately if we're not authorized + return false, fmt.Errorf("status %d: invalid credentials", status) + } + return 
isOK(status), nil +} diff --git a/client/go/vespa/target_custom.go b/client/go/vespa/target_custom.go new file mode 100644 index 00000000000..072ec8649e4 --- /dev/null +++ b/client/go/vespa/target_custom.go @@ -0,0 +1,128 @@ +package vespa + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/vespa-engine/vespa/client/go/version" +) + +type customTarget struct { + targetType string + baseURL string +} + +type serviceConvergeResponse struct { + Converged bool `json:"converged"` +} + +// LocalTarget creates a target for a Vespa platform running locally. +func LocalTarget() Target { + return &customTarget{targetType: TargetLocal, baseURL: "http://127.0.0.1"} +} + +// CustomTarget creates a Target for a Vespa platform running at baseURL. +func CustomTarget(baseURL string) Target { + return &customTarget{targetType: TargetCustom, baseURL: baseURL} +} + +func (t *customTarget) Type() string { return t.targetType } + +func (t *customTarget) Deployment() Deployment { return Deployment{} } + +func (t *customTarget) createService(name string) (*Service, error) { + switch name { + case DeployService, QueryService, DocumentService: + url, err := t.urlWithPort(name) + if err != nil { + return nil, err + } + return &Service{BaseURL: url, Name: name}, nil + } + return nil, fmt.Errorf("unknown service: %s", name) +} + +func (t *customTarget) Service(name string, timeout time.Duration, sessionOrRunID int64, cluster string) (*Service, error) { + service, err := t.createService(name) + if err != nil { + return nil, err + } + if timeout > 0 { + if name == DeployService { + status, err := service.Wait(timeout) + if err != nil { + return nil, err + } + if !isOK(status) { + return nil, fmt.Errorf("got status %d from deploy service at %s", status, service.BaseURL) + } + } else { + if err := t.waitForConvergence(timeout); err != nil { + return nil, err + } + } + } + return service, nil +} + +func (t *customTarget) PrintLog(options LogOptions) error { + 
return fmt.Errorf("reading logs from non-cloud deployment is unsupported") +} + +func (t *customTarget) SignRequest(req *http.Request, sigKeyId string) error { return nil } + +func (t *customTarget) CheckVersion(version version.Version) error { return nil } + +func (t *customTarget) urlWithPort(serviceName string) (string, error) { + u, err := url.Parse(t.baseURL) + if err != nil { + return "", err + } + port := u.Port() + if port == "" { + switch serviceName { + case DeployService: + port = "19071" + case QueryService, DocumentService: + port = "8080" + default: + return "", fmt.Errorf("unknown service: %s", serviceName) + } + u.Host = u.Host + ":" + port + } + return u.String(), nil +} + +func (t *customTarget) waitForConvergence(timeout time.Duration) error { + deployURL, err := t.urlWithPort(DeployService) + if err != nil { + return err + } + url := fmt.Sprintf("%s/application/v2/tenant/default/application/default/environment/prod/region/default/instance/default/serviceconverge", deployURL) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return err + } + converged := false + convergedFunc := func(status int, response []byte) (bool, error) { + if !isOK(status) { + return false, nil + } + var resp serviceConvergeResponse + if err := json.Unmarshal(response, &resp); err != nil { + return false, nil + } + converged = resp.Converged + return converged, nil + } + if _, err := wait(convergedFunc, func() *http.Request { return req }, nil, timeout); err != nil { + return err + } + if !converged { + return fmt.Errorf("services have not converged") + } + return nil +} diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java index e0e4318ccc8..7cb36374568 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java @@ -116,6 +116,7 @@ 
public interface ModelContext { @ModelFeatureFlag(owners = {"arnej"}) default boolean useQrserverServiceName() { return true; } @ModelFeatureFlag(owners = {"bjorncs", "baldersheim"}) default boolean enableJdiscPreshutdownCommand() { return true; } @ModelFeatureFlag(owners = {"arnej"}) default boolean avoidRenamingSummaryFeatures() { return false; } + @ModelFeatureFlag(owners = {"bjorncs", "baldersheim"}) default boolean mergeGroupingResultInSearchInvoker() { return false; } } /** Warning: As elsewhere in this package, do not make backwards incompatible changes that will break old config models! */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java index 22b752777e9..e483351a25a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java @@ -158,7 +158,7 @@ public class ContentSearchCluster extends AbstractConfigProducer<SearchCluster> String clusterName, ContentSearchCluster search) { List<ModelElement> indexedDefs = getIndexedSchemas(clusterElem); if (!indexedDefs.isEmpty()) { - IndexedSearchCluster isc = new IndexedSearchCluster(search, clusterName, 0); + IndexedSearchCluster isc = new IndexedSearchCluster(deployState, search, clusterName, 0); isc.setRoutingSelector(clusterElem.childAsString("documents.selection")); Double visibilityDelay = clusterElem.childAsDouble("engine.proton.visibility-delay"); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java index fb7c6696b54..53aac23135a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java @@ -6,7 +6,6 @@ import 
com.yahoo.config.model.producer.AbstractConfigProducer; import com.yahoo.prelude.fastsearch.DocumentdbInfoConfig; import com.yahoo.search.config.IndexInfoConfig; import com.yahoo.searchdefinition.DocumentOnlySchema; -import com.yahoo.searchdefinition.Schema; import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.vespa.config.search.DispatchConfig; @@ -34,6 +33,7 @@ public class IndexedSearchCluster extends SearchCluster IlscriptsConfig.Producer, DispatchConfig.Producer { + private final boolean mergeGroupingResultInSearchInvoker; private String indexingClusterName = null; // The name of the docproc cluster to run indexing, by config. private String indexingChainName = null; @@ -63,10 +63,11 @@ public class IndexedSearchCluster extends SearchCluster return routingSelector; } - public IndexedSearchCluster(AbstractConfigProducer<SearchCluster> parent, String clusterName, int index) { + public IndexedSearchCluster(DeployState deployState, AbstractConfigProducer<SearchCluster> parent, String clusterName, int index) { super(parent, clusterName, index); unionCfg = new UnionConfiguration(this, documentDbs); rootDispatch = new DispatchGroup(this); + mergeGroupingResultInSearchInvoker = deployState.featureFlags().mergeGroupingResultInSearchInvoker(); } @Override @@ -320,6 +321,7 @@ public class IndexedSearchCluster extends SearchCluster builder.maxWaitAfterCoverageFactor(searchCoverage.getMaxWaitAfterCoverageFactor()); } builder.warmuptime(5.0); + builder.mergeGroupingResultInSearchInvokerEnabled(mergeGroupingResultInSearchInvoker); } @Override diff --git a/config-model/src/main/resources/schema/deployment.rnc b/config-model/src/main/resources/schema/deployment.rnc index 3e751a379d4..1aaf002b703 100644 --- a/config-model/src/main/resources/schema/deployment.rnc +++ b/config-model/src/main/resources/schema/deployment.rnc @@ -52,7 +52,8 @@ ParallelInstances = element parallel { Upgrade = element 
upgrade { attribute policy { xsd:string }? & - attribute revision { xsd:string }? & + attribute revision-target { xsd:string }? & + attribute revision-change { xsd:string }? & attribute rollout { xsd:string }? } diff --git a/configdefinitions/src/vespa/dispatch.def b/configdefinitions/src/vespa/dispatch.def index 17f42a73bfd..fef9300a410 100644 --- a/configdefinitions/src/vespa/dispatch.def +++ b/configdefinitions/src/vespa/dispatch.def @@ -71,3 +71,6 @@ node[].host string # The rpc port of this search node node[].port int + +# Temporary feature flag +mergeGroupingResultInSearchInvokerEnabled bool default=false diff --git a/configdefinitions/src/vespa/stor-filestor.def b/configdefinitions/src/vespa/stor-filestor.def index e54b503ed93..531805d3039 100644 --- a/configdefinitions/src/vespa/stor-filestor.def +++ b/configdefinitions/src/vespa/stor-filestor.def @@ -80,7 +80,7 @@ resource_usage_reporter_noise_level double default=0.001 ## - DYNAMIC uses DynamicThrottlePolicy under the hood and will block if the window ## is full (if a blocking throttler API call is invoked). ## -async_operation_throttler.type enum { UNLIMITED, DYNAMIC } default=UNLIMITED restart +async_operation_throttler.type enum { UNLIMITED, DYNAMIC } default=UNLIMITED ## Internal throttler tuning parameters that only apply when type == DYNAMIC: async_operation_throttler.window_size_increment int default=20 async_operation_throttler.window_size_decrement_factor double default=1.2 @@ -104,7 +104,7 @@ async_operation_throttler.throttle_individual_merge_feed_ops bool default=true ## is full (if a blocking throttler API call is invoked). ## ## TODO deprecate in favor of the async_operation_throttler struct instead. 
-async_operation_throttler_type enum { UNLIMITED, DYNAMIC } default=UNLIMITED restart +async_operation_throttler_type enum { UNLIMITED, DYNAMIC } default=UNLIMITED ## Specifies the extent the throttling window is increased by when the async throttle ## policy has decided that more concurrent operations are desirable. Also affects the diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java index 5a813c7886a..6222e7b1788 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java @@ -209,6 +209,7 @@ public class ModelContextImpl implements ModelContext { private final boolean inhibitDefaultMergesWhenGlobalMergesPending; private final boolean useQrserverServiceName; private final boolean avoidRenamingSummaryFeatures; + private final boolean mergeGroupingResultInSearchInvoker; public FeatureFlags(FlagSource source, ApplicationId appId) { this.defaultTermwiseLimit = flagValue(source, appId, Flags.DEFAULT_TERM_WISE_LIMIT); @@ -256,6 +257,7 @@ public class ModelContextImpl implements ModelContext { this.inhibitDefaultMergesWhenGlobalMergesPending = flagValue(source, appId, Flags.INHIBIT_DEFAULT_MERGES_WHEN_GLOBAL_MERGES_PENDING); this.useQrserverServiceName = flagValue(source, appId, Flags.USE_QRSERVER_SERVICE_NAME); this.avoidRenamingSummaryFeatures = flagValue(source, appId, Flags.AVOID_RENAMING_SUMMARY_FEATURES); + this.mergeGroupingResultInSearchInvoker = flagValue(source, appId, Flags.MERGE_GROUPING_RESULT_IN_SEARCH_INVOKER); } @Override public double defaultTermwiseLimit() { return defaultTermwiseLimit; } @@ -305,6 +307,7 @@ public class ModelContextImpl implements ModelContext { @Override public boolean inhibitDefaultMergesWhenGlobalMergesPending() { return inhibitDefaultMergesWhenGlobalMergesPending; } 
@Override public boolean useQrserverServiceName() { return useQrserverServiceName; } @Override public boolean avoidRenamingSummaryFeatures() { return avoidRenamingSummaryFeatures; } + @Override public boolean mergeGroupingResultInSearchInvoker() { return mergeGroupingResultInSearchInvoker; } private static <V> V flagValue(FlagSource source, ApplicationId appId, UnboundFlag<? extends V, ?, ?> flag) { return flag.bindTo(source) diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/GroupingResultAggregator.java b/container-search/src/main/java/com/yahoo/search/dispatch/GroupingResultAggregator.java new file mode 100644 index 00000000000..5ce7accfdd4 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/dispatch/GroupingResultAggregator.java @@ -0,0 +1,50 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.dispatch; + +import com.yahoo.prelude.fastsearch.DocsumDefinitionSet; +import com.yahoo.prelude.fastsearch.GroupingListHit; +import com.yahoo.searchlib.aggregation.Grouping; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Incrementally merges underlying {@link Grouping} instances from {@link GroupingListHit} hits. 
+ * + * @author bjorncs + */ +class GroupingResultAggregator { + private static final Logger log = Logger.getLogger(GroupingResultAggregator.class.getName()); + + private final Map<Integer, Grouping> groupings = new LinkedHashMap<>(); + private DocsumDefinitionSet documentDefinitions = null; + private int groupingHitsMerged = 0; + + void mergeWith(GroupingListHit result) { + if (groupingHitsMerged == 0) documentDefinitions = result.getDocsumDefinitionSet(); + ++groupingHitsMerged; + log.log(Level.FINE, () -> + String.format("Merging hit #%d having %d groupings", + groupingHitsMerged, result.getGroupingList().size())); + for (Grouping grouping : result.getGroupingList()) { + groupings.merge(grouping.getId(), grouping, (existingGrouping, newGrouping) -> { + existingGrouping.merge(newGrouping); + return existingGrouping; + }); + } + } + + Optional<GroupingListHit> toAggregatedHit() { + if (groupingHitsMerged == 0) return Optional.empty(); + log.log(Level.FINE, () -> + String.format("Creating aggregated hit containing %d groupings from %d hits", + groupings.size(), groupingHitsMerged)); + groupings.values().forEach(Grouping::postMerge); + return Optional.of(new GroupingListHit(List.copyOf(groupings.values()), documentDefinitions)); + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java index 4e658122cdf..d7c9f1dce53 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.search.dispatch; +import com.yahoo.prelude.fastsearch.GroupingListHit; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.dispatch.searchcluster.Group; @@ -44,6 +45,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM private final Group group; private final LinkedBlockingQueue<SearchInvoker> availableForProcessing; private final Set<Integer> alreadyFailedNodes; + private final boolean mergeGroupingResult; private Query query; private boolean adaptiveTimeoutCalculated = false; @@ -71,6 +73,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM this.group = group; this.availableForProcessing = newQueue(); this.alreadyFailedNodes = alreadyFailedNodes; + this.mergeGroupingResult = searchCluster.dispatchConfig().mergeGroupingResultInSearchInvokerEnabled(); } /** @@ -115,6 +118,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM long nextTimeout = query.getTimeLeft(); boolean extraDebug = (query.getOffset() == 0) && (query.getHits() == 7) && log.isLoggable(java.util.logging.Level.FINE); List<InvokerResult> processed = new ArrayList<>(); + var groupingResultAggregator = new GroupingResultAggregator(); try { while (!invokers.isEmpty() && nextTimeout >= 0) { SearchInvoker invoker = availableForProcessing.poll(nextTimeout, TimeUnit.MILLISECONDS); @@ -126,7 +130,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM if (extraDebug) { processed.add(toMerge); } - merged = mergeResult(result.getResult(), toMerge, merged); + merged = mergeResult(result.getResult(), toMerge, merged, groupingResultAggregator); ejectInvoker(invoker); } nextTimeout = nextTimeout(); @@ -134,6 +138,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM } catch (InterruptedException e) { throw new RuntimeException("Interrupted while waiting for search results", e); } + 
groupingResultAggregator.toAggregatedHit().ifPresent(h -> result.getResult().hits().add(h)); insertNetworkErrors(result.getResult()); result.getResult().setCoverage(createCoverage()); @@ -238,14 +243,20 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM return nextAdaptive; } - private List<LeanHit> mergeResult(Result result, InvokerResult partialResult, List<LeanHit> current) { + private List<LeanHit> mergeResult(Result result, InvokerResult partialResult, List<LeanHit> current, + GroupingResultAggregator groupingResultAggregator) { collectCoverage(partialResult.getResult().getCoverage(true)); result.mergeWith(partialResult.getResult()); List<Hit> partialNonLean = partialResult.getResult().hits().asUnorderedHits(); for(Hit hit : partialNonLean) { if (hit.isAuxiliary()) { - result.hits().add(hit); + if (hit instanceof GroupingListHit && mergeGroupingResult) { + var groupingHit = (GroupingListHit) hit; + groupingResultAggregator.mergeWith(groupingHit); + } else { + result.hits().add(hit); + } } } if (current.isEmpty() ) { diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java index 6e08b1c6fa5..347276d680d 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java @@ -4,6 +4,7 @@ package com.yahoo.search.dispatch; import com.yahoo.document.GlobalId; import com.yahoo.document.idstring.IdString; import com.yahoo.prelude.fastsearch.FastHit; +import com.yahoo.prelude.fastsearch.GroupingListHit; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.dispatch.searchcluster.Group; @@ -13,6 +14,11 @@ import com.yahoo.search.result.Coverage; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.result.Hit; import 
com.yahoo.search.result.Relevance; +import com.yahoo.searchlib.aggregation.Grouping; +import com.yahoo.searchlib.aggregation.MaxAggregationResult; +import com.yahoo.searchlib.aggregation.MinAggregationResult; +import com.yahoo.searchlib.expression.IntegerResultNode; +import com.yahoo.searchlib.expression.StringResultNode; import com.yahoo.test.ManualClock; import org.junit.Test; @@ -320,6 +326,39 @@ public class InterleavedSearchInvokerTest { assertEquals(3, result.getQuery().getHits()); } + @Test + public void requireThatGroupingsAreMerged() throws IOException { + SearchCluster cluster = new MockSearchCluster("!", 1, 2); + List<SearchInvoker> invokers = new ArrayList<>(); + + Grouping grouping1 = new Grouping(0); + grouping1.setRoot(new com.yahoo.searchlib.aggregation.Group() + .addChild(new com.yahoo.searchlib.aggregation.Group() + .setId(new StringResultNode("uniqueA")) + .addAggregationResult(new MaxAggregationResult().setMax(new IntegerResultNode(6)).setTag(4))) + .addChild(new com.yahoo.searchlib.aggregation.Group() + .setId(new StringResultNode("common")) + .addAggregationResult(new MaxAggregationResult().setMax(new IntegerResultNode(9)).setTag(4)))); + invokers.add(new MockInvoker(0).setHits(List.of(new GroupingListHit(List.of(grouping1))))); + + Grouping grouping2 = new Grouping(0); + grouping2.setRoot(new com.yahoo.searchlib.aggregation.Group() + .addChild(new com.yahoo.searchlib.aggregation.Group() + .setId(new StringResultNode("uniqueB")) + .addAggregationResult(new MaxAggregationResult().setMax(new IntegerResultNode(9)).setTag(4))) + .addChild(new com.yahoo.searchlib.aggregation.Group() + .setId(new StringResultNode("common")) + .addAggregationResult(new MinAggregationResult().setMin(new IntegerResultNode(6)).setTag(3)))); + invokers.add(new MockInvoker(0).setHits(List.of(new GroupingListHit(List.of(grouping2))))); + + InterleavedSearchInvoker invoker = new InterleavedSearchInvoker(invokers, cluster, new Group(0, List.of()), Collections.emptySet()); + 
invoker.responseAvailable(invokers.get(0)); + invoker.responseAvailable(invokers.get(1)); + Result result = invoker.search(query, null); + assertEquals(1, ((GroupingListHit) result.hits().get(0)).getGroupingList().size()); + + } + private static InterleavedSearchInvoker createInterLeavedTestInvoker(List<Double> a, List<Double> b, Group group) { SearchCluster cluster = new MockSearchCluster("!", 1, 2); List<SearchInvoker> invokers = new ArrayList<>(); diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java b/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java index 2acce0f8d2d..54c8c1e0522 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java @@ -120,7 +120,8 @@ public class MockSearchCluster extends SearchCluster { builder.minActivedocsPercentage(88.0); builder.minGroupCoverage(99.0); builder.minSearchCoverage(minSearchCoverage); - builder.distributionPolicy(DispatchConfig.DistributionPolicy.Enum.ROUNDROBIN); + builder.distributionPolicy(DispatchConfig.DistributionPolicy.Enum.ROUNDROBIN) + .mergeGroupingResultInSearchInvokerEnabled(true); if (minSearchCoverage < 100.0) { builder.minWaitAfterCoverageFactor(0); builder.maxWaitAfterCoverageFactor(0.5); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java index f36f5be7778..8c933f98277 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java @@ -176,10 +176,14 @@ public class DeploymentStatus { Map<JobId, List<Job>> jobs = jobsToRun(changes); // Add test jobs for any outstanding change. 
- for (InstanceName instance : application.deploymentSpec().instanceNames()) - changes.put(instance, outstandingChange(instance).onTopOf(application.require(instance).change())); - var testJobs = jobsToRun(changes, true).entrySet().stream() - .filter(entry -> ! entry.getKey().type().isProduction()); + Map<InstanceName, Change> outstandingChanges = new LinkedHashMap<>(); + for (InstanceName instance : application.deploymentSpec().instanceNames()) { + Change outstanding = outstandingChange(instance); + if (outstanding.hasTargets()) + outstandingChanges.put(instance, outstanding.onTopOf(application.require(instance).change())); + } + var testJobs = jobsToRun(outstandingChanges, true).entrySet().stream() + .filter(entry -> ! entry.getKey().type().isProduction()); return Stream.concat(jobs.entrySet().stream(), testJobs) .collect(collectingAndThen(toMap(Map.Entry::getKey, @@ -347,8 +351,8 @@ public class DeploymentStatus { || step.completedAt(change.withoutPlatform(), Optional.of(job)).isPresent()) return List.of(change); - // For a dual change, where both target remain, we determine what to run by looking at when the two parts became ready: - // for deployments, we look at dependencies; for tests, this may be overridden by what is already deployed. + // For a dual change, where both targets remain, we determine what to run by looking at when the two parts became ready: + // for deployments, we look at dependencies; for production tests, this may be overridden by what is already deployed. JobId deployment = new JobId(job.application(), JobType.from(system, job.type().zone(system)).get()); UpgradeRollout rollout = application.deploymentSpec().requireInstance(job.application().instance()).upgradeRollout(); if (job.type().isTest()) { @@ -405,10 +409,13 @@ public class DeploymentStatus { // Both changes are ready for this step, and we look to the specified rollout to decide. 
boolean platformReadyFirst = platformReadyAt.get().isBefore(revisionReadyAt.get()); boolean revisionReadyFirst = revisionReadyAt.get().isBefore(platformReadyAt.get()); + boolean failingUpgradeOnlyTests = ! jobs().type(systemTest, stagingTest) + .failingHardOn(Versions.from(change.withoutApplication(), application, deploymentFor(job), systemVersion)) + .isEmpty(); switch (rollout) { case separate: // Let whichever change rolled out first, keep rolling first, unless upgrade alone is failing. return (platformReadyFirst || platformReadyAt.get().equals(Instant.EPOCH)) // Assume platform was first if no jobs have run yet. - ? step.job().flatMap(jobs()::get).flatMap(JobStatus::firstFailing).isPresent() + ? step.job().flatMap(jobs()::get).flatMap(JobStatus::firstFailing).isPresent() || failingUpgradeOnlyTests ? List.of(change) // Platform was first, but is failing. : List.of(change.withoutApplication(), change) // Platform was first, and is OK. : revisionReadyFirst @@ -900,6 +907,11 @@ public class DeploymentStatus { return Objects.hash(versions, readyAt, change); } + @Override + public String toString() { + return change + " with versions " + versions + ", ready at " + readyAt; + } + } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java index ed5df62ca5d..aeaa821745b 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java @@ -372,11 +372,11 @@ public class DeploymentTrigger { .ifPresent(last -> { if (jobs.get(job).stream().noneMatch(versions -> versions.versions().targetsMatch(last.versions()) && versions.versions().sourcesMatchIfPresent(last.versions()))) { - log.log(Level.INFO, "Aborting outdated run " + last); - 
controller.jobController().abort(last.id(), "run no longer scheduled, and is blocking scheduled runs: " + - jobs.get(job).stream() - .map(scheduled -> scheduled.versions().toString()) - .collect(Collectors.joining(", "))); + String blocked = jobs.get(job).stream() + .map(scheduled -> scheduled.versions().toString()) + .collect(Collectors.joining(", ")); + log.log(Level.INFO, "Aborting outdated run " + last + ", which is blocking runs: " + blocked); + controller.jobController().abort(last.id(), "run no longer scheduled, and is blocking scheduled runs: " + blocked); } }); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobList.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobList.java index 5de07bad859..d06bdc45583 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobList.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobList.java @@ -17,6 +17,7 @@ import java.util.function.Function; import java.util.function.Predicate; import static com.yahoo.vespa.hosted.controller.deployment.RunStatus.aborted; +import static com.yahoo.vespa.hosted.controller.deployment.RunStatus.nodeAllocationFailure; /** * A list of deployment jobs that can be filtered in various ways. @@ -102,6 +103,15 @@ public class JobList extends AbstractFilteringList<JobStatus, JobList> { return matching(job -> job.id().type().isProduction()); } + /** Returns the jobs with any runs failing with non-out-of-test-capacity on the given versions — targets only for system test, everything present otherwise. */ + public JobList failingHardOn(Versions versions) { + return matching(job -> ! 
RunList.from(job) + .on(versions) + .matching(Run::hasFailed) + .not().matching(run -> run.status() == nodeAllocationFailure && run.id().type().environment().isTest()) + .isEmpty()); + } + /** Returns the jobs with any runs matching the given versions — targets only for system test, everything present otherwise. */ public JobList triggeredOn(Versions versions) { return matching(job -> ! RunList.from(job).on(versions).isEmpty()); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java index b95d34f5414..b2847a29654 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java @@ -698,7 +698,10 @@ public class DeploymentTriggerTest { ApplicationVersion revision1 = app1.lastSubmission().get(); app1.submit(applicationPackage); ApplicationVersion revision2 = app1.lastSubmission().get(); - app1.runJob(systemTest).runJob(stagingTest); + app1.runJob(systemTest) // Tests for new revision on version2 + .runJob(stagingTest) + .runJob(systemTest) // Tests for new revision on version1 + .runJob(stagingTest); assertEquals(Change.of(version1).with(revision2), app1.instance().change()); tester.triggerJobs(); app1.assertRunning(productionUsCentral1); @@ -718,9 +721,7 @@ public class DeploymentTriggerTest { app1.assertNotRunning(productionUsCentral1); // Last job has a different deployment target, so tests need to run again. - app1.runJob(systemTest) - .runJob(stagingTest) // Eager test of outstanding change, assuming upgrade in west succeeds. - .runJob(productionEuWest1) // Upgrade completes, and revision is the only change. + app1.runJob(productionEuWest1) // Upgrade completes, and revision is the only change. 
.runJob(productionUsCentral1) // With only revision change, central should run to cover a previous failure. .runJob(productionEuWest1); // Finally, west changes revision. assertEquals(Change.empty(), app1.instance().change()); @@ -1433,6 +1434,21 @@ public class DeploymentTriggerTest { tester.jobs().last(app.instanceId(), productionUsWest1).get().versions()); app.runJob(productionUsWest1); assertEquals(Change.empty(), app.instance().change()); + + // New upgrade fails in staging-test, and revision to fix it is submitted. + var version2 = new Version("7.2"); + tester.controllerTester().upgradeSystem(version2); + tester.upgrader().maintain(); + app.runJob(systemTest).failDeployment(stagingTest); + tester.clock().advance(Duration.ofMinutes(30)); + app.failDeployment(stagingTest); + app.submit(appPackage); + + app.runJob(systemTest).runJob(stagingTest) // Tests run with combined upgrade. + .runJob(productionUsCentral1) // Combined upgrade stays together. + .runJob(productionUsEast3).runJob(productionUsWest1); + assertEquals(Map.of(), app.deploymentStatus().jobsToRun()); + assertEquals(Change.empty(), app.instance().change()); } @Test diff --git a/document/src/main/java/com/yahoo/document/BucketDistribution.java b/document/src/main/java/com/yahoo/document/BucketDistribution.java index abacd4fdc2f..e963f57e22b 100644 --- a/document/src/main/java/com/yahoo/document/BucketDistribution.java +++ b/document/src/main/java/com/yahoo/document/BucketDistribution.java @@ -1,8 +1,6 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.document; -import com.yahoo.document.BucketId; - import java.util.ArrayList; import java.util.List; import java.util.logging.Level; @@ -14,7 +12,7 @@ import java.util.logging.Logger; public class BucketDistribution { // A logger object to enable proper logging. 
- private static Logger log = Logger.getLogger(BucketDistribution.class.getName()); + private static final Logger log = Logger.getLogger(BucketDistribution.class.getName()); // A map from bucket id to column index. private int[] bucketToColumn; diff --git a/document/src/main/java/com/yahoo/document/DocumentId.java b/document/src/main/java/com/yahoo/document/DocumentId.java index 8c35eaa0329..3512c5cb7b7 100644 --- a/document/src/main/java/com/yahoo/document/DocumentId.java +++ b/document/src/main/java/com/yahoo/document/DocumentId.java @@ -11,6 +11,7 @@ import com.yahoo.vespa.objects.Identifiable; import com.yahoo.vespa.objects.Serializer; import java.io.Serializable; +import java.util.Objects; /** * The id of a document @@ -18,11 +19,9 @@ import java.io.Serializable; public class DocumentId extends Identifiable implements Serializable { private IdString id; - private GlobalId globalId; + private GlobalId globalId = null; - /** - * Constructor used for deserialization. - */ + /** Constructor used for deserialization. */ public DocumentId(Deserializer buf) { deserialize(buf); } @@ -33,21 +32,14 @@ public class DocumentId extends Identifiable implements Serializable { * The document id string can only contain text characters. */ public DocumentId(String id) { - if (id == null) { - throw new IllegalArgumentException("Cannot create DocumentId from null id."); - } - if (id.length() > IdString.MAX_LENGTH) { - throw new IllegalArgumentException("The document id(" + id.length() + ") is too long(" + IdString.MAX_LENGTH + "). 
" + - "However if you have already fed a document earlier on and want to remove it, you can do so by " + - "calling new DocumentId(IdString.createIdStringLessStrict()) that will bypass this restriction."); - } - this.id = IdString.createIdString(id); - globalId = null; + this.id = IdString.createIdString(Objects.requireNonNull(id)); + if (id.length() > IdString.MAX_LENGTH) + throw new IllegalArgumentException("Document id of length " + id.length() + + " is longer than the max " + IdString.MAX_LENGTH); } public DocumentId(IdString id) { this.id = id; - globalId = null; } /** @@ -86,14 +78,17 @@ public class DocumentId extends Identifiable implements Serializable { return id.toString().compareTo(cmp.id.toString()); } + @Override public boolean equals(Object o) { return o instanceof DocumentId && id.equals(((DocumentId)o).id); } + @Override public int hashCode() { return id.hashCode(); } + @Override public String toString() { return id.toString(); } @@ -107,7 +102,6 @@ public class DocumentId extends Identifiable implements Serializable { } } - public void onDeserialize(Deserializer data) throws DeserializationException { if (data instanceof DocumentReader) { id = ((DocumentReader)data).readDocumentId().getScheme(); @@ -123,4 +117,5 @@ public class DocumentId extends Identifiable implements Serializable { public String getDocType() { return id.getDocType(); } + } diff --git a/document/src/main/java/com/yahoo/document/GlobalId.java b/document/src/main/java/com/yahoo/document/GlobalId.java index 9e90b59171c..b9d454dd007 100644 --- a/document/src/main/java/com/yahoo/document/GlobalId.java +++ b/document/src/main/java/com/yahoo/document/GlobalId.java @@ -40,10 +40,10 @@ public class GlobalId implements Comparable { /** * Constructs a new global id from a document id string. * - * @param id The document id to derive from. 
+ * @param id the document id to derive from */ public GlobalId(IdString id) { - byte [] raw = MD5.md5.get().digest(id.toUtf8().wrap().array()); + byte[] raw = MD5.md5.get().digest(id.toUtf8().wrap().array()); long location = id.getLocation(); this.raw = new byte [LENGTH]; for (int i = 0; i < 4; ++i) { @@ -57,7 +57,7 @@ public class GlobalId implements Comparable { /** * Constructs a global id by deserializing content from the given byte buffer. * - * @param buf The buffer to deserialize from. + * @param buf the buffer to deserialize from */ public GlobalId(Deserializer buf) { raw = buf.getBytes(null, LENGTH); @@ -66,7 +66,7 @@ public class GlobalId implements Comparable { /** * Serializes the content of this global id into the given byte buffer. * - * @param buf The buffer to serialize to. + * @param buf the buffer to serialize to */ public void serialize(Serializer buf) { buf.put(null, raw); diff --git a/document/src/test/java/com/yahoo/document/DocumentIdTestCase.java b/document/src/test/java/com/yahoo/document/DocumentIdTestCase.java index 63a0f8d25ed..0d3b07fd6ee 100644 --- a/document/src/test/java/com/yahoo/document/DocumentIdTestCase.java +++ b/document/src/test/java/com/yahoo/document/DocumentIdTestCase.java @@ -266,8 +266,7 @@ public class DocumentIdTestCase { new DocumentId(sb.toString()); fail("Expected an IllegalArgumentException to be thrown"); } catch (IllegalArgumentException ex) { - assertTrue(ex.getMessage().contains("However if you have already fed a document earlier on and want to remove it, " + - "you can do so by calling new DocumentId(IdString.createIdStringLessStrict()) that will bypass this restriction.")); + assertEquals("Document id of length 65548 is longer than the max 65536", ex.getMessage()); } assertEquals(65548, new DocumentId(IdString.createIdStringLessStrict(sb.toString())).toString().length()); } diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index 
881b62d1e04..5376fa983af 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -390,6 +390,13 @@ public class Flags { "Takes effect immediately", TENANT_ID); + public static final UnboundBooleanFlag MERGE_GROUPING_RESULT_IN_SEARCH_INVOKER = defineFeatureFlag( + "merge-grouping-result-in-search-invoker", false, + List.of("bjorncs", "baldersheim"), "2022-02-23", "2022-08-01", + "Merge grouping results incrementally in interleaved search invoker", + "Takes effect at redeployment", + APPLICATION_ID); + /** WARNING: public for testing: All flags should be defined in {@link Flags}. */ public static UnboundBooleanFlag defineFeatureFlag(String flagId, boolean defaultValue, List<String> owners, String createdAt, String expiresAt, String description, diff --git a/searchlib/src/main/java/com/yahoo/searchlib/aggregation/Grouping.java b/searchlib/src/main/java/com/yahoo/searchlib/aggregation/Grouping.java index 25b3cb18ff9..c88a567c559 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/aggregation/Grouping.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/aggregation/Grouping.java @@ -46,6 +46,8 @@ public class Grouping extends Identifiable { // Actual root group, does not require level details. private Group root = new Group(); + private boolean postMergeCompleted = false; + /** * <p>Constructs an empty result node. 
<b>NOTE:</b> This instance is broken until non-optional member data is * set.</p> @@ -78,7 +80,9 @@ public class Grouping extends Identifiable { * that might have changes due to the merge.</p> */ public void postMerge() { + if (postMergeCompleted) return; root.postMerge(groupingLevels, firstLevel, 0); + postMergeCompleted = true; } /** diff --git a/searchlib/src/tests/rankingexpression/rankingexpressionlist b/searchlib/src/tests/rankingexpression/rankingexpressionlist index def378069a8..c032ffd1a01 100644 --- a/searchlib/src/tests/rankingexpression/rankingexpressionlist +++ b/searchlib/src/tests/rankingexpression/rankingexpressionlist @@ -165,3 +165,5 @@ if(1.09999~=1.1,2,3); if (1.09999 ~= 1.1, 2, 3) 1 && 0 || 1 !a && (a || a) 10 ^ 3 +true +false diff --git a/storage/src/vespa/storage/persistence/filestorage/filestorhandler.h b/storage/src/vespa/storage/persistence/filestorage/filestorhandler.h index 5c243ea4af9..6d8cdc16743 100644 --- a/storage/src/vespa/storage/persistence/filestorage/filestorhandler.h +++ b/storage/src/vespa/storage/persistence/filestorage/filestorhandler.h @@ -128,6 +128,7 @@ public: CLOSED }; + FileStorHandler() : _getNextMessageTimout(100ms) { } virtual ~FileStorHandler() = default; @@ -170,7 +171,12 @@ public: * * @param stripe The stripe to get messages for */ - virtual LockedMessage getNextMessage(uint32_t stripeId) = 0; + virtual LockedMessage getNextMessage(uint32_t stripeId, vespalib::steady_time deadline) = 0; + + /** Only used for testing, should be removed */ + LockedMessage getNextMessage(uint32_t stripeId) { + return getNextMessage(stripeId, vespalib::steady_clock::now() + _getNextMessageTimout); + } /** * Lock a bucket. 
By default, each file stor thread has the locks of all @@ -268,7 +274,7 @@ public: virtual uint32_t getQueueSize() const = 0; // Commands used by testing - virtual void setGetNextMessageTimeout(vespalib::duration timeout) = 0; + void setGetNextMessageTimeout(vespalib::duration timeout) { _getNextMessageTimout = timeout; } virtual std::string dumpQueue() const = 0; @@ -276,7 +282,13 @@ public: virtual vespalib::SharedOperationThrottler& operation_throttler() const noexcept = 0; + virtual void reconfigure_dynamic_throttler(const vespalib::SharedOperationThrottler::DynamicThrottleParams& params) = 0; + + virtual void use_dynamic_operation_throttling(bool use_dynamic) noexcept = 0; + virtual void set_throttle_apply_bucket_diff_ops(bool throttle_apply_bucket_diff) noexcept = 0; +private: + vespalib::duration _getNextMessageTimout; }; } // storage diff --git a/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.cpp b/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.cpp index b5de5a233cc..c44ae305fa2 100644 --- a/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.cpp +++ b/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.cpp @@ -5,7 +5,6 @@ #include "mergestatus.h" #include <vespa/storageapi/message/bucketsplitting.h> #include <vespa/storageapi/message/persistence.h> -#include <vespa/storageapi/message/removelocation.h> #include <vespa/storage/bucketdb/storbucketdb.h> #include <vespa/storage/common/bucketmessages.h> #include <vespa/storage/common/statusmessages.h> @@ -40,22 +39,23 @@ uint32_t per_stripe_merge_limit(uint32_t num_threads, uint32_t num_stripes) noex FileStorHandlerImpl::FileStorHandlerImpl(MessageSender& sender, FileStorMetrics& metrics, ServiceLayerComponentRegister& compReg) - : FileStorHandlerImpl(1, 1, sender, metrics, compReg, vespalib::SharedOperationThrottler::make_unlimited_throttler()) + : FileStorHandlerImpl(1, 1, sender, metrics, compReg, 
vespalib::SharedOperationThrottler::DynamicThrottleParams()) { } FileStorHandlerImpl::FileStorHandlerImpl(uint32_t numThreads, uint32_t numStripes, MessageSender& sender, FileStorMetrics& metrics, ServiceLayerComponentRegister& compReg, - std::unique_ptr<vespalib::SharedOperationThrottler> operation_throttler) + const vespalib::SharedOperationThrottler::DynamicThrottleParams& dyn_throttle_params) : _component(compReg, "filestorhandlerimpl"), _state(FileStorHandler::AVAILABLE), _metrics(nullptr), - _operation_throttler(std::move(operation_throttler)), + _dynamic_operation_throttler(vespalib::SharedOperationThrottler::make_dynamic_throttler(dyn_throttle_params)), + _unlimited_operation_throttler(vespalib::SharedOperationThrottler::make_unlimited_throttler()), + _active_throttler(_unlimited_operation_throttler.get()), // Will be set by FileStorManager _stripes(), _messageSender(sender), _bucketIdFactory(_component.getBucketIdFactory()), - _getNextMessageTimeout(100ms), _max_active_merges_per_stripe(per_stripe_merge_limit(numThreads, numStripes)), _paused(false), _throttle_apply_bucket_diff_ops(false), @@ -251,6 +251,22 @@ FileStorHandlerImpl::schedule_and_get_next_async_message(const std::shared_ptr<a return {}; } +void +FileStorHandlerImpl::reconfigure_dynamic_throttler(const vespalib::SharedOperationThrottler::DynamicThrottleParams& params) +{ + _dynamic_operation_throttler->reconfigure_dynamic_throttling(params); +} + +void +FileStorHandlerImpl::use_dynamic_operation_throttling(bool use_dynamic) noexcept +{ + // Use release semantics instead of relaxed to ensure transitive visibility even in + // non-persistence threads that try to invoke the throttler (i.e. RPC threads). + _active_throttler.store(use_dynamic ? 
_dynamic_operation_throttler.get() + : _unlimited_operation_throttler.get(), + std::memory_order_release); +} + bool FileStorHandlerImpl::messageMayBeAborted(const api::StorageMessage& msg) { @@ -333,9 +349,9 @@ FileStorHandlerImpl::updateMetrics(const MetricLockGuard &) std::lock_guard lockGuard(_mergeStatesLock); _metrics->pendingMerges.addValue(_mergeStates.size()); _metrics->queueSize.addValue(getQueueSize()); - _metrics->throttle_window_size.addValue(_operation_throttler->current_window_size()); - _metrics->throttle_waiting_threads.addValue(_operation_throttler->waiting_threads()); - _metrics->throttle_active_tokens.addValue(_operation_throttler->current_active_token_count()); + _metrics->throttle_window_size.addValue(operation_throttler().current_window_size()); + _metrics->throttle_waiting_threads.addValue(operation_throttler().waiting_threads()); + _metrics->throttle_active_tokens.addValue(operation_throttler().current_active_token_count()); for (const auto & stripe : _metrics->stripes) { const auto & m = stripe->averageQueueWaitingTime; @@ -377,13 +393,13 @@ FileStorHandlerImpl::makeQueueTimeoutReply(api::StorageMessage& msg) } FileStorHandler::LockedMessage -FileStorHandlerImpl::getNextMessage(uint32_t stripeId) +FileStorHandlerImpl::getNextMessage(uint32_t stripeId, vespalib::steady_time deadline) { if (!tryHandlePause()) { return {}; // Still paused, return to allow tick. 
} - return getNextMessage(stripeId, _getNextMessageTimeout); + return _stripes[stripeId].getNextMessage(deadline); } std::shared_ptr<FileStorHandler::BucketLockInterface> @@ -919,7 +935,7 @@ FileStorHandlerImpl::Stripe::operation_type_should_be_throttled(api::MessageType } FileStorHandler::LockedMessage -FileStorHandlerImpl::Stripe::getNextMessage(vespalib::duration timeout) +FileStorHandlerImpl::Stripe::getNextMessage(vespalib::steady_time deadline) { std::unique_lock guard(*_lock); ThrottleToken throttle_token; @@ -955,12 +971,12 @@ FileStorHandlerImpl::Stripe::getNextMessage(vespalib::duration timeout) // Depending on whether we were blocked due to no usable ops in queue or throttling, // wait for either the queue or throttler to (hopefully) have some fresh stuff for us. if (!was_throttled) { - _cond->wait_for(guard, timeout); + _cond->wait_until(guard, deadline); } else { // Have to release lock before doing a blocking throttle token fetch, since it // prevents RPC threads from pushing onto the queue. 
guard.unlock(); - throttle_token = _owner.operation_throttler().blocking_acquire_one(timeout); + throttle_token = _owner.operation_throttler().blocking_acquire_one(deadline); guard.lock(); if (!throttle_token.valid()) { _metrics->timeouts_waiting_for_throttle_token.inc(); diff --git a/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.h b/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.h index 1bc0ab87b1c..dbef1d06dad 100644 --- a/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.h +++ b/storage/src/vespa/storage/persistence/filestorage/filestorhandlerimpl.h @@ -132,7 +132,7 @@ public: std::shared_ptr<FileStorHandler::BucketLockInterface> lock(const document::Bucket & bucket, api::LockingRequirements lockReq); void failOperations(const document::Bucket & bucket, const api::ReturnCode & code); - FileStorHandler::LockedMessage getNextMessage(vespalib::duration timeout); + FileStorHandler::LockedMessage getNextMessage(vespalib::steady_time deadline); void dumpQueue(std::ostream & os) const; void dumpActiveHtml(std::ostream & os) const; void dumpQueueHtml(std::ostream & os) const; @@ -192,10 +192,10 @@ public: FileStorHandlerImpl(MessageSender& sender, FileStorMetrics& metrics, ServiceLayerComponentRegister& compReg); FileStorHandlerImpl(uint32_t numThreads, uint32_t numStripes, MessageSender&, FileStorMetrics&, - ServiceLayerComponentRegister&, std::unique_ptr<vespalib::SharedOperationThrottler>); + ServiceLayerComponentRegister&, + const vespalib::SharedOperationThrottler::DynamicThrottleParams& dyn_throttle_params); ~FileStorHandlerImpl() override; - void setGetNextMessageTimeout(vespalib::duration timeout) override { _getNextMessageTimeout = timeout; } void flush(bool killPendingMerges) override; void setDiskState(DiskState state) override; @@ -204,7 +204,7 @@ public: bool schedule(const std::shared_ptr<api::StorageMessage>&) override; ScheduleAsyncResult schedule_and_get_next_async_message(const 
std::shared_ptr<api::StorageMessage>& msg) override; - FileStorHandler::LockedMessage getNextMessage(uint32_t stripeId) override; + FileStorHandler::LockedMessage getNextMessage(uint32_t stripeId, vespalib::steady_time deadline) override; void remapQueueAfterJoin(const RemapInfo& source, RemapInfo& target) override; void remapQueueAfterSplit(const RemapInfo& source, RemapInfo& target1, RemapInfo& target2) override; @@ -245,9 +245,18 @@ public: void abortQueuedOperations(const AbortBucketOperationsCommand& cmd) override; vespalib::SharedOperationThrottler& operation_throttler() const noexcept override { - return *_operation_throttler; + // It would be reasonable to assume that this could be a relaxed load since the set + // of possible throttlers is static and all _persistence_ thread creation is sequenced + // after throttler creation. But since the throttler may be invoked by RPC threads + // created in another context, use acquire semantics to ensure transitive visibility. + // TODO remove need for atomics once the throttler testing dust settles + return *_active_throttler.load(std::memory_order_acquire); } + void reconfigure_dynamic_throttler(const vespalib::SharedOperationThrottler::DynamicThrottleParams& params) override; + + void use_dynamic_operation_throttling(bool use_dynamic) noexcept override; + void set_throttle_apply_bucket_diff_ops(bool throttle_apply_bucket_diff) noexcept override { // Relaxed is fine, worst case from temporarily observing a stale value is that // an ApplyBucketDiff message is (or isn't) throttled at a high level. 
@@ -264,13 +273,14 @@ private: ServiceLayerComponent _component; std::atomic<DiskState> _state; FileStorDiskMetrics * _metrics; - std::unique_ptr<vespalib::SharedOperationThrottler> _operation_throttler; + std::unique_ptr<vespalib::SharedOperationThrottler> _dynamic_operation_throttler; + std::unique_ptr<vespalib::SharedOperationThrottler> _unlimited_operation_throttler; + std::atomic<vespalib::SharedOperationThrottler*> _active_throttler; std::vector<Stripe> _stripes; MessageSender& _messageSender; const document::BucketIdFactory& _bucketIdFactory; mutable std::mutex _mergeStatesLock; std::map<document::Bucket, std::shared_ptr<MergeStatus>> _mergeStates; - vespalib::duration _getNextMessageTimeout; const uint32_t _max_active_merges_per_stripe; // Read concurrently by stripes. mutable std::mutex _pauseMonitor; mutable std::condition_variable _pauseCond; @@ -355,9 +365,6 @@ private: Stripe & stripe(const document::Bucket & bucket) { return _stripes[stripe_index(bucket)]; } - FileStorHandler::LockedMessage getNextMessage(uint32_t stripeId, vespalib::duration timeout) { - return _stripes[stripeId].getNextMessage(timeout); - } ActiveOperationsStats get_active_operations_stats(bool reset_min_max) const override; }; diff --git a/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp b/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp index f4aff96b53c..09bd842c308 100644 --- a/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp +++ b/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp @@ -164,20 +164,6 @@ dynamic_throttle_params_from_config(const StorFilestorConfig& config, uint32_t n return params; } -std::unique_ptr<vespalib::SharedOperationThrottler> -make_operation_throttler_from_config(const StorFilestorConfig& config, uint32_t num_threads) -{ - // TODO only use struct config field instead once config model is updated - const bool use_dynamic_throttling = ((config.asyncOperationThrottlerType == 
StorFilestorConfig::AsyncOperationThrottlerType::DYNAMIC) || - (config.asyncOperationThrottler.type == StorFilestorConfig::AsyncOperationThrottler::Type::DYNAMIC)); - if (use_dynamic_throttling) { - auto dyn_params = dynamic_throttle_params_from_config(config, num_threads); - return vespalib::SharedOperationThrottler::make_dynamic_throttler(dyn_params); - } else { - return vespalib::SharedOperationThrottler::make_unlimited_throttler(); - } -} - #ifdef __PIC__ #define TLS_LINKAGE __attribute__((visibility("hidden"), tls_model("initial-exec"))) #else @@ -216,18 +202,17 @@ FileStorManager::getThreadLocalHandler() { } return *_G_threadLocalHandler; } -/** - * If live configuration, assuming storageserver makes sure no messages are - * incoming during reconfiguration - */ + void FileStorManager::configure(std::unique_ptr<StorFilestorConfig> config) { // If true, this is not the first configure. - bool liveUpdate = ! _threads.empty(); + const bool liveUpdate = ! _threads.empty(); _use_async_message_handling_on_schedule = config->useAsyncMessageHandlingOnSchedule; _host_info_reporter.set_noise_level(config->resourceUsageReporterNoiseLevel); + const bool use_dynamic_throttling = ((config->asyncOperationThrottlerType == StorFilestorConfig::AsyncOperationThrottlerType::DYNAMIC) || + (config->asyncOperationThrottler.type == StorFilestorConfig::AsyncOperationThrottler::Type::DYNAMIC)); const bool throttle_merge_feed_ops = config->asyncOperationThrottler.throttleIndividualMergeFeedOps; if (!liveUpdate) { @@ -235,10 +220,10 @@ FileStorManager::configure(std::unique_ptr<StorFilestorConfig> config) uint32_t numThreads = std::max(1, _config->numThreads); uint32_t numStripes = std::max(1u, numThreads / 2); _metrics->initDiskMetrics(numStripes, computeAllPossibleHandlerThreads(*_config)); - auto operation_throttler = make_operation_throttler_from_config(*_config, numThreads); + auto dyn_params = dynamic_throttle_params_from_config(*_config, numThreads); _filestorHandler = 
std::make_unique<FileStorHandlerImpl>(numThreads, numStripes, *this, *_metrics, - _compReg, std::move(operation_throttler)); + _compReg, dyn_params); uint32_t numResponseThreads = computeNumResponseThreads(_config->numResponseThreads); _sequencedExecutor = vespalib::SequencedTaskExecutor::create(CpuUsage::wrap(response_executor, CpuUsage::Category::WRITE), numResponseThreads, 10000, @@ -253,10 +238,11 @@ FileStorManager::configure(std::unique_ptr<StorFilestorConfig> config) } else { assert(_filestorHandler); auto updated_dyn_throttle_params = dynamic_throttle_params_from_config(*config, _threads.size()); - _filestorHandler->operation_throttler().reconfigure_dynamic_throttling(updated_dyn_throttle_params); + _filestorHandler->reconfigure_dynamic_throttler(updated_dyn_throttle_params); } // TODO remove once desired dynamic throttling behavior is set in stone { + _filestorHandler->use_dynamic_operation_throttling(use_dynamic_throttling); _filestorHandler->set_throttle_apply_bucket_diff_ops(!throttle_merge_feed_ops); std::lock_guard guard(_lock); for (auto& ph : _persistenceHandlers) { diff --git a/storage/src/vespa/storage/persistence/mergehandler.cpp b/storage/src/vespa/storage/persistence/mergehandler.cpp index 78f53de46b3..8287fe27509 100644 --- a/storage/src/vespa/storage/persistence/mergehandler.cpp +++ b/storage/src/vespa/storage/persistence/mergehandler.cpp @@ -32,7 +32,6 @@ MergeHandler::MergeHandler(PersistenceUtil& env, spi::PersistenceProvider& spi, _cluster_context(cluster_context), _env(env), _spi(spi), - _operation_throttler(_env._fileStorHandler.operation_throttler()), _monitored_ref_count(std::make_unique<MonitoredRefCount>()), _maxChunkSize(maxChunkSize), _commonMergeChainOptimalizationMinimumSize(commonMergeChainOptimalizationMinimumSize), @@ -515,7 +514,7 @@ MergeHandler::applyDiffEntry(std::shared_ptr<ApplyBucketDiffState> async_results spi::Context& context, const document::DocumentTypeRepo& repo) const { - auto throttle_token = 
throttle_merge_feed_ops() ? _operation_throttler.blocking_acquire_one() + auto throttle_token = throttle_merge_feed_ops() ? _env._fileStorHandler.operation_throttler().blocking_acquire_one() : vespalib::SharedOperationThrottler::Token(); spi::Timestamp timestamp(e._entry._timestamp); if (!(e._entry._flags & (DELETED | DELETED_IN_PLACE))) { diff --git a/storage/src/vespa/storage/persistence/mergehandler.h b/storage/src/vespa/storage/persistence/mergehandler.h index 93fb7efc8d0..1ed2fa878bc 100644 --- a/storage/src/vespa/storage/persistence/mergehandler.h +++ b/storage/src/vespa/storage/persistence/mergehandler.h @@ -24,7 +24,6 @@ namespace vespalib { class ISequencedTaskExecutor; -class SharedOperationThrottler; } namespace storage { @@ -96,7 +95,6 @@ private: const ClusterContext &_cluster_context; PersistenceUtil &_env; spi::PersistenceProvider &_spi; - vespalib::SharedOperationThrottler& _operation_throttler; std::unique_ptr<vespalib::MonitoredRefCount> _monitored_ref_count; const uint32_t _maxChunkSize; const uint32_t _commonMergeChainOptimalizationMinimumSize; diff --git a/storage/src/vespa/storage/persistence/persistencethread.cpp b/storage/src/vespa/storage/persistence/persistencethread.cpp index b89c60d4720..8e1fdb06ded 100644 --- a/storage/src/vespa/storage/persistence/persistencethread.cpp +++ b/storage/src/vespa/storage/persistence/persistencethread.cpp @@ -33,10 +33,13 @@ PersistenceThread::run(framework::ThreadHandle& thread) { LOG(debug, "Started persistence thread"); + vespalib::duration max_wait_time = vespalib::adjustTimeoutByDetectedHz(100ms); while (!thread.interrupted()) { - thread.registerTick(); + vespalib::steady_time now = vespalib::steady_clock::now(); + thread.registerTick(framework::UNKNOWN_CYCLE, now); - FileStorHandler::LockedMessage lock(_fileStorHandler.getNextMessage(_stripeId)); + vespalib::steady_time deadline = now + max_wait_time; + FileStorHandler::LockedMessage lock(_fileStorHandler.getNextMessage(_stripeId, deadline)); if 
(lock.lock) { _persistenceHandler.processLockedMessage(std::move(lock)); diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java b/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java index 10e55527a2a..7b3e488a5a5 100644 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java @@ -214,7 +214,7 @@ public class DocumentV1ApiHandler extends AbstractRequestHandler { this.operations = new ConcurrentLinkedDeque<>(); long resendDelayMS = SystemTimer.adjustTimeoutByDetectedHz(Duration.ofMillis(executorConfig.resendDelayMillis())).toMillis(); - //TODO Here it would be better do have dedicated threads with different wait depending on blocked or empty. + // TODO: Here it would be better do have dedicated threads with different wait depending on blocked or empty. 
this.dispatcher.scheduleWithFixedDelay(this::dispatchEnqueued, resendDelayMS, resendDelayMS, MILLISECONDS); this.visitDispatcher.scheduleWithFixedDelay(this::dispatchVisitEnqueued, resendDelayMS, resendDelayMS, MILLISECONDS); } diff --git a/vespalib/src/tests/shared_operation_throttler/shared_operation_throttler_test.cpp b/vespalib/src/tests/shared_operation_throttler/shared_operation_throttler_test.cpp index eefc0ca72c0..d6946905236 100644 --- a/vespalib/src/tests/shared_operation_throttler/shared_operation_throttler_test.cpp +++ b/vespalib/src/tests/shared_operation_throttler/shared_operation_throttler_test.cpp @@ -4,6 +4,8 @@ #include <vespa/vespalib/util/barrier.h> #include <thread> +using vespalib::steady_clock; + namespace vespalib { using ThrottleToken = SharedOperationThrottler::Token; @@ -47,7 +49,7 @@ TEST_F("blocking acquire returns immediately if slot available", DynamicThrottle auto token = f1._throttler->blocking_acquire_one(); EXPECT_TRUE(token.valid()); token.reset(); - token = f1._throttler->blocking_acquire_one(600s); // Should never block. + token = f1._throttler->blocking_acquire_one(steady_clock::now() + 600s); // Should never block. EXPECT_TRUE(token.valid()); } @@ -70,11 +72,11 @@ TEST_F("blocking call woken up if throttle slot available", DynamicThrottleFixtu TEST_F("time-bounded blocking acquire waits for timeout", DynamicThrottleFixture()) { auto window_filling_token = f1._throttler->try_acquire_one(); - auto before = std::chrono::steady_clock::now(); + auto before = steady_clock::now(); // Will block for at least 1ms. Since no window slot will be available by that time, // an invalid token should be returned. 
- auto token = f1._throttler->blocking_acquire_one(1ms); - auto after = std::chrono::steady_clock::now(); + auto token = f1._throttler->blocking_acquire_one(before + 1ms); + auto after = steady_clock::now(); EXPECT_TRUE((after - before) >= 1ms); EXPECT_FALSE(token.valid()); } diff --git a/vespalib/src/tests/util/generationhandler/CMakeLists.txt b/vespalib/src/tests/util/generationhandler/CMakeLists.txt index 677d5caa0e6..fdf54c59854 100644 --- a/vespalib/src/tests/util/generationhandler/CMakeLists.txt +++ b/vespalib/src/tests/util/generationhandler/CMakeLists.txt @@ -4,5 +4,6 @@ vespa_add_executable(vespalib_generationhandler_test_app TEST generationhandler_test.cpp DEPENDS vespalib + GTest::GTest ) vespa_add_test(NAME vespalib_generationhandler_test_app COMMAND vespalib_generationhandler_test_app) diff --git a/vespalib/src/tests/util/generationhandler/generationhandler_test.cpp b/vespalib/src/tests/util/generationhandler/generationhandler_test.cpp index f269fe729fa..00da752a749 100644 --- a/vespalib/src/tests/util/generationhandler/generationhandler_test.cpp +++ b/vespalib/src/tests/util/generationhandler/generationhandler_test.cpp @@ -1,157 +1,137 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-#include <vespa/vespalib/testkit/testapp.h> +#include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/util/generationhandler.h> #include <deque> namespace vespalib { -typedef GenerationHandler::Guard GenGuard; +using GenGuard = GenerationHandler::Guard; -class Test : public vespalib::TestApp { -private: - void requireThatGenerationCanBeIncreased(); - void requireThatReadersCanTakeGuards(); - void requireThatGuardsCanBeCopied(); - void requireThatTheFirstUsedGenerationIsCorrect(); - void requireThatGenerationCanGrowLarge(); -public: - int Main() override; +class GenerationHandlerTest : public ::testing::Test { +protected: + GenerationHandler gh; + GenerationHandlerTest(); + ~GenerationHandlerTest() override; }; -void -Test::requireThatGenerationCanBeIncreased() +GenerationHandlerTest::GenerationHandlerTest() + : ::testing::Test(), + gh() { - GenerationHandler gh; - EXPECT_EQUAL(0u, gh.getCurrentGeneration()); - EXPECT_EQUAL(0u, gh.getFirstUsedGeneration()); +} + +GenerationHandlerTest::~GenerationHandlerTest() = default; + +TEST_F(GenerationHandlerTest, require_that_generation_can_be_increased) +{ + EXPECT_EQ(0u, gh.getCurrentGeneration()); + EXPECT_EQ(0u, gh.getFirstUsedGeneration()); gh.incGeneration(); - EXPECT_EQUAL(1u, gh.getCurrentGeneration()); - EXPECT_EQUAL(1u, gh.getFirstUsedGeneration()); + EXPECT_EQ(1u, gh.getCurrentGeneration()); + EXPECT_EQ(1u, gh.getFirstUsedGeneration()); } -void -Test::requireThatReadersCanTakeGuards() +TEST_F(GenerationHandlerTest, require_that_readers_can_take_guards) { - GenerationHandler gh; - EXPECT_EQUAL(0u, gh.getGenerationRefCount(0)); + EXPECT_EQ(0u, gh.getGenerationRefCount(0)); { GenGuard g1 = gh.takeGuard(); - EXPECT_EQUAL(1u, gh.getGenerationRefCount(0)); + EXPECT_EQ(1u, gh.getGenerationRefCount(0)); { GenGuard g2 = gh.takeGuard(); - EXPECT_EQUAL(2u, gh.getGenerationRefCount(0)); + EXPECT_EQ(2u, gh.getGenerationRefCount(0)); gh.incGeneration(); { GenGuard g3 = gh.takeGuard(); - EXPECT_EQUAL(2u, 
gh.getGenerationRefCount(0)); - EXPECT_EQUAL(1u, gh.getGenerationRefCount(1)); - EXPECT_EQUAL(3u, gh.getGenerationRefCount()); + EXPECT_EQ(2u, gh.getGenerationRefCount(0)); + EXPECT_EQ(1u, gh.getGenerationRefCount(1)); + EXPECT_EQ(3u, gh.getGenerationRefCount()); } - EXPECT_EQUAL(2u, gh.getGenerationRefCount(0)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(1)); + EXPECT_EQ(2u, gh.getGenerationRefCount(0)); + EXPECT_EQ(0u, gh.getGenerationRefCount(1)); gh.incGeneration(); { GenGuard g3 = gh.takeGuard(); - EXPECT_EQUAL(2u, gh.getGenerationRefCount(0)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(1)); - EXPECT_EQUAL(1u, gh.getGenerationRefCount(2)); + EXPECT_EQ(2u, gh.getGenerationRefCount(0)); + EXPECT_EQ(0u, gh.getGenerationRefCount(1)); + EXPECT_EQ(1u, gh.getGenerationRefCount(2)); } - EXPECT_EQUAL(2u, gh.getGenerationRefCount(0)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(1)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(2)); + EXPECT_EQ(2u, gh.getGenerationRefCount(0)); + EXPECT_EQ(0u, gh.getGenerationRefCount(1)); + EXPECT_EQ(0u, gh.getGenerationRefCount(2)); } - EXPECT_EQUAL(1u, gh.getGenerationRefCount(0)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(1)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(2)); + EXPECT_EQ(1u, gh.getGenerationRefCount(0)); + EXPECT_EQ(0u, gh.getGenerationRefCount(1)); + EXPECT_EQ(0u, gh.getGenerationRefCount(2)); } - EXPECT_EQUAL(0u, gh.getGenerationRefCount(0)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(1)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(2)); + EXPECT_EQ(0u, gh.getGenerationRefCount(0)); + EXPECT_EQ(0u, gh.getGenerationRefCount(1)); + EXPECT_EQ(0u, gh.getGenerationRefCount(2)); } -void -Test::requireThatGuardsCanBeCopied() +TEST_F(GenerationHandlerTest, require_that_guards_can_be_copied) { - GenerationHandler gh; GenGuard g1 = gh.takeGuard(); - EXPECT_EQUAL(1u, gh.getGenerationRefCount(0)); + EXPECT_EQ(1u, gh.getGenerationRefCount(0)); GenGuard g2(g1); - EXPECT_EQUAL(2u, gh.getGenerationRefCount(0)); + EXPECT_EQ(2u, 
gh.getGenerationRefCount(0)); gh.incGeneration(); GenGuard g3 = gh.takeGuard(); - EXPECT_EQUAL(2u, gh.getGenerationRefCount(0)); - EXPECT_EQUAL(1u, gh.getGenerationRefCount(1)); + EXPECT_EQ(2u, gh.getGenerationRefCount(0)); + EXPECT_EQ(1u, gh.getGenerationRefCount(1)); g3 = g2; - EXPECT_EQUAL(3u, gh.getGenerationRefCount(0)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(1)); + EXPECT_EQ(3u, gh.getGenerationRefCount(0)); + EXPECT_EQ(0u, gh.getGenerationRefCount(1)); } -void -Test::requireThatTheFirstUsedGenerationIsCorrect() +TEST_F(GenerationHandlerTest, require_that_the_first_used_generation_is_correct) { - GenerationHandler gh; - EXPECT_EQUAL(0u, gh.getFirstUsedGeneration()); + EXPECT_EQ(0u, gh.getFirstUsedGeneration()); gh.incGeneration(); - EXPECT_EQUAL(1u, gh.getFirstUsedGeneration()); + EXPECT_EQ(1u, gh.getFirstUsedGeneration()); { GenGuard g1 = gh.takeGuard(); gh.incGeneration(); - EXPECT_EQUAL(1u, gh.getGenerationRefCount()); - EXPECT_EQUAL(1u, gh.getFirstUsedGeneration()); + EXPECT_EQ(1u, gh.getGenerationRefCount()); + EXPECT_EQ(1u, gh.getFirstUsedGeneration()); } - EXPECT_EQUAL(1u, gh.getFirstUsedGeneration()); + EXPECT_EQ(1u, gh.getFirstUsedGeneration()); gh.updateFirstUsedGeneration(); // Only writer should call this - EXPECT_EQUAL(0u, gh.getGenerationRefCount()); - EXPECT_EQUAL(2u, gh.getFirstUsedGeneration()); + EXPECT_EQ(0u, gh.getGenerationRefCount()); + EXPECT_EQ(2u, gh.getFirstUsedGeneration()); { GenGuard g1 = gh.takeGuard(); gh.incGeneration(); gh.incGeneration(); - EXPECT_EQUAL(1u, gh.getGenerationRefCount()); - EXPECT_EQUAL(2u, gh.getFirstUsedGeneration()); + EXPECT_EQ(1u, gh.getGenerationRefCount()); + EXPECT_EQ(2u, gh.getFirstUsedGeneration()); { GenGuard g2 = gh.takeGuard(); - EXPECT_EQUAL(2u, gh.getFirstUsedGeneration()); + EXPECT_EQ(2u, gh.getFirstUsedGeneration()); } } - EXPECT_EQUAL(2u, gh.getFirstUsedGeneration()); + EXPECT_EQ(2u, gh.getFirstUsedGeneration()); gh.updateFirstUsedGeneration(); // Only writer should call this - 
EXPECT_EQUAL(0u, gh.getGenerationRefCount()); - EXPECT_EQUAL(4u, gh.getFirstUsedGeneration()); + EXPECT_EQ(0u, gh.getGenerationRefCount()); + EXPECT_EQ(4u, gh.getFirstUsedGeneration()); } -void -Test::requireThatGenerationCanGrowLarge() +TEST_F(GenerationHandlerTest, require_that_generation_can_grow_large) { - GenerationHandler gh; std::deque<GenGuard> guards; for (size_t i = 0; i < 10000; ++i) { - EXPECT_EQUAL(i, gh.getCurrentGeneration()); + EXPECT_EQ(i, gh.getCurrentGeneration()); guards.push_back(gh.takeGuard()); // take guard on current generation if (i >= 128) { - EXPECT_EQUAL(i - 128, gh.getFirstUsedGeneration()); + EXPECT_EQ(i - 128, gh.getFirstUsedGeneration()); guards.pop_front(); - EXPECT_EQUAL(128u, gh.getGenerationRefCount()); + EXPECT_EQ(128u, gh.getGenerationRefCount()); } gh.incGeneration(); } } -int -Test::Main() -{ - TEST_INIT("generationhandler_test"); - - TEST_DO(requireThatGenerationCanBeIncreased()); - TEST_DO(requireThatReadersCanTakeGuards()); - TEST_DO(requireThatGuardsCanBeCopied()); - TEST_DO(requireThatTheFirstUsedGenerationIsCorrect()); - TEST_DO(requireThatGenerationCanGrowLarge()); - - TEST_DONE(); -} - } -TEST_APPHOOK(vespalib::Test); +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/tests/util/generationhandler_stress/CMakeLists.txt b/vespalib/src/tests/util/generationhandler_stress/CMakeLists.txt index 569db489e3c..7e5a5af79b5 100644 --- a/vespalib/src/tests/util/generationhandler_stress/CMakeLists.txt +++ b/vespalib/src/tests/util/generationhandler_stress/CMakeLists.txt @@ -4,5 +4,6 @@ vespa_add_executable(vespalib_generation_handler_stress_test_app generation_handler_stress_test.cpp DEPENDS vespalib + GTest::GTest ) -vespa_add_test(NAME vespalib_generation_handler_stress_test_app COMMAND vespalib_generation_handler_stress_test_app BENCHMARK) +vespa_add_test(NAME vespalib_generation_handler_stress_test_app COMMAND vespalib_generation_handler_stress_test_app --smoke-test) diff --git 
a/vespalib/src/tests/util/generationhandler_stress/generation_handler_stress_test.cpp b/vespalib/src/tests/util/generationhandler_stress/generation_handler_stress_test.cpp index fa2c525b518..0689909da09 100644 --- a/vespalib/src/tests/util/generationhandler_stress/generation_handler_stress_test.cpp +++ b/vespalib/src/tests/util/generationhandler_stress/generation_handler_stress_test.cpp @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/log/log.h> LOG_SETUP("generation_handler_stress_test"); -#include <vespa/vespalib/testkit/testapp.h> +#include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/util/generationhandler.h> #include <vespa/vespalib/util/threadstackexecutor.h> #include <vespa/vespalib/util/size_literals.h> @@ -10,6 +10,12 @@ using vespalib::Executor; using vespalib::GenerationHandler; using vespalib::ThreadStackExecutor; +namespace { + +bool smoke_test = false; +const vespalib::string smoke_test_option = "--smoke-test"; + +} struct WorkContext { @@ -21,26 +27,28 @@ struct WorkContext } }; -struct Fixture { +class Fixture : public ::testing::Test { +protected: GenerationHandler _generationHandler; uint32_t _readThreads; ThreadStackExecutor _writer; // 1 write thread - ThreadStackExecutor _readers; // multiple reader threads + std::unique_ptr<ThreadStackExecutor> _readers; // multiple reader threads std::atomic<long> _readSeed; std::atomic<long> _doneWriteWork; std::atomic<long> _doneReadWork; std::atomic<int> _stopRead; bool _reportWork; - Fixture(uint32_t readThreads = 1); - + Fixture(); ~Fixture(); - void readWork(const WorkContext &context); - void writeWork(uint32_t cnt, WorkContext &context); + void set_read_threads(uint32_t read_threads); + uint32_t getReadThreads() const { return _readThreads; } void stressTest(uint32_t writeCnt); - +public: + void readWork(const WorkContext &context); + void writeWork(uint32_t cnt, WorkContext &context); private: 
Fixture(const Fixture &index) = delete; Fixture(Fixture &&index) = delete; @@ -49,23 +57,27 @@ private: }; -Fixture::Fixture(uint32_t readThreads) - : _generationHandler(), - _readThreads(readThreads), +Fixture::Fixture() + : ::testing::Test(), + _generationHandler(), + _readThreads(1), _writer(1, 128_Ki), - _readers(readThreads, 128_Ki), + _readers(), _doneWriteWork(0), _doneReadWork(0), _stopRead(0), _reportWork(false) { + set_read_threads(1); } Fixture::~Fixture() { - _readers.sync(); - _readers.shutdown(); + if (_readers) { + _readers->sync(); + _readers->shutdown(); + } _writer.sync(); _writer.shutdown(); if (_reportWork) { @@ -75,6 +87,16 @@ Fixture::~Fixture() } } +void +Fixture::set_read_threads(uint32_t read_threads) +{ + if (_readers) { + _readers->sync(); + _readers->shutdown(); + } + _readThreads = read_threads; + _readers = std::make_unique<ThreadStackExecutor>(read_threads, 128_Ki); +} void Fixture::readWork(const WorkContext &context) @@ -85,7 +107,7 @@ Fixture::readWork(const WorkContext &context) for (i = 0; i < cnt && _stopRead.load() == 0; ++i) { auto guard = _generationHandler.takeGuard(); auto generation = context._generation.load(std::memory_order_relaxed); - EXPECT_GREATER_EQUAL(generation, guard.getGeneration()); + EXPECT_GE(generation, guard.getGeneration()); } _doneReadWork += i; LOG(info, "done %u read work", i); @@ -150,19 +172,32 @@ Fixture::stressTest(uint32_t writeCnt) auto context = std::make_shared<WorkContext>(); _writer.execute(std::make_unique<WriteWorkTask>(*this, writeCnt, context)); for (uint32_t i = 0; i < readThreads; ++i) { - _readers.execute(std::make_unique<ReadWorkTask>(*this, context)); + _readers->execute(std::make_unique<ReadWorkTask>(*this, context)); } + _writer.sync(); + _readers->sync(); } +using GenerationHandlerStressTest = Fixture; -TEST_F("stress test, 2 readers", Fixture(2)) +TEST_F(GenerationHandlerStressTest, stress_test_2_readers) { - f.stressTest(1000000); + set_read_threads(2); + stressTest(smoke_test ? 
10000 : 1000000); } -TEST_F("stress test, 4 readers", Fixture(4)) +TEST_F(GenerationHandlerStressTest, stress_test_4_readers) { - f.stressTest(1000000); + set_read_threads(4); + stressTest(smoke_test ? 10000 : 1000000); } -TEST_MAIN() { TEST_RUN_ALL(); } +int main(int argc, char **argv) { + if (argc > 1 && argv[1] == smoke_test_option) { + smoke_test = true; + ++argv; + --argc; + } + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/vespalib/src/vespa/vespalib/util/shared_operation_throttler.cpp b/vespalib/src/vespa/vespalib/util/shared_operation_throttler.cpp index 6e273d1a7ea..478d8c1b9e9 100644 --- a/vespalib/src/vespa/vespalib/util/shared_operation_throttler.cpp +++ b/vespalib/src/vespa/vespalib/util/shared_operation_throttler.cpp @@ -24,7 +24,7 @@ public: internal_ref_count_increase(); return Token(this, TokenCtorTag{}); } - Token blocking_acquire_one(vespalib::duration) noexcept override { + Token blocking_acquire_one(vespalib::steady_time) noexcept override { internal_ref_count_increase(); return Token(this, TokenCtorTag{}); } @@ -267,7 +267,7 @@ public: ~DynamicOperationThrottler() override; Token blocking_acquire_one() noexcept override; - Token blocking_acquire_one(vespalib::duration timeout) noexcept override; + Token blocking_acquire_one(vespalib::steady_time deadline) noexcept override; Token try_acquire_one() noexcept override; uint32_t current_window_size() const noexcept override; uint32_t current_active_token_count() const noexcept override; @@ -334,12 +334,12 @@ DynamicOperationThrottler::blocking_acquire_one() noexcept } DynamicOperationThrottler::Token -DynamicOperationThrottler::blocking_acquire_one(vespalib::duration timeout) noexcept +DynamicOperationThrottler::blocking_acquire_one(vespalib::steady_time deadline) noexcept { std::unique_lock lock(_mutex); if (!has_spare_capacity_in_active_window()) { ++_waiting_threads; - const bool accepted = _cond.wait_for(lock, timeout, [&] { + const bool accepted = 
_cond.wait_until(lock, deadline, [&] { return has_spare_capacity_in_active_window(); }); --_waiting_threads; diff --git a/vespalib/src/vespa/vespalib/util/shared_operation_throttler.h b/vespalib/src/vespa/vespalib/util/shared_operation_throttler.h index b7913029c1e..95d6d361cb6 100644 --- a/vespalib/src/vespa/vespalib/util/shared_operation_throttler.h +++ b/vespalib/src/vespa/vespalib/util/shared_operation_throttler.h @@ -54,9 +54,9 @@ public: // Acquire a valid throttling token, uninterruptedly blocking until one can be obtained. [[nodiscard]] virtual Token blocking_acquire_one() noexcept = 0; // Attempt to acquire a valid throttling token, waiting up to `timeout` for one to be - // available. If the timeout is exceeded without any tokens becoming available, an + // available. If the deadline is reached without any tokens becoming available, an // invalid token will be returned. - [[nodiscard]] virtual Token blocking_acquire_one(vespalib::duration timeout) noexcept = 0; + [[nodiscard]] virtual Token blocking_acquire_one(vespalib::steady_time deadline) noexcept = 0; // Attempt to acquire a valid throttling token if one is immediately available. // An invalid token will be returned if none is available. Never blocks (other than // when contending for the internal throttler mutex). |