diff options
18 files changed, 774 insertions, 578 deletions
diff --git a/client/go/vespa/target.go b/client/go/vespa/target.go index 2314cc06f33..8d48f1ffa3a 100644 --- a/client/go/vespa/target.go +++ b/client/go/vespa/target.go @@ -3,23 +3,15 @@ package vespa import ( - "bytes" "crypto/tls" - "encoding/json" "fmt" "io" "io/ioutil" - "math" "net/http" - "net/url" - "sort" - "strconv" "time" - "github.com/vespa-engine/vespa/client/go/auth0" "github.com/vespa-engine/vespa/client/go/util" "github.com/vespa-engine/vespa/client/go/version" - "github.com/vespa-engine/vespa/client/go/zts" ) const ( @@ -94,26 +86,6 @@ type LogOptions struct { Level int } -// CloudOptions configures URL and authentication for a cloud target. -type APIOptions struct { - System System - TLSOptions TLSOptions - APIKey []byte - AuthConfigPath string -} - -// CloudDeploymentOptions configures the deployment to manage through a cloud target. -type CloudDeploymentOptions struct { - Deployment Deployment - TLSOptions TLSOptions - ClusterURLs map[string]string // Endpoints keyed on cluster name -} - -type customTarget struct { - targetType string - baseURL string -} - // Do sends request to this service. Any required authentication happens automatically. func (s *Service) Do(request *http.Request, timeout time.Duration) (*http.Response, error) { if s.TLSOptions.KeyPair.Certificate != nil { @@ -143,12 +115,7 @@ func (s *Service) Wait(timeout time.Duration) (int, error) { default: return 0, fmt.Errorf("invalid service: %s", s.Name) } - req, err := http.NewRequest("GET", url, nil) - if err != nil { - return 0, err - } - okFunc := func(status int, response []byte) (bool, error) { return status/100 == 2, nil } - return wait(okFunc, func() *http.Request { return req }, &s.TLSOptions.KeyPair, timeout) + return waitForOK(url, &s.TLSOptions.KeyPair, timeout) } func (s *Service) Description() string { @@ -163,442 +130,23 @@ func (s *Service) Description() string { return fmt.Sprintf("No description of service %s", s.Name) } -func (t *customTarget) Type() string { return t.targetType } - -func (t *customTarget) Deployment() Deployment { return Deployment{} } - -func (t *customTarget) Service(name string, timeout time.Duration, sessionOrRunID int64, cluster string) (*Service, error) { - if timeout > 0 && name != DeployService { - if err := t.waitForConvergence(timeout); err != nil { - return nil, err - } - } - switch name { - case DeployService, QueryService, DocumentService: - url, err := t.urlWithPort(name) - if err != nil { - return nil, err - } - return &Service{BaseURL: url, Name: name}, nil - } - return nil, fmt.Errorf("unknown service: %s", name) -} - -func (t *customTarget) PrintLog(options LogOptions) error { - return fmt.Errorf("reading logs from non-cloud deployment is unsupported") -} - -func (t *customTarget) SignRequest(req *http.Request, sigKeyId string) error { return nil } +func isOK(status int) bool { return status/100 == 2 } -func (t *customTarget) CheckVersion(version version.Version) error { return nil } +type responseFunc func(status int, response []byte) (bool, error) -func (t *customTarget) urlWithPort(serviceName string) (string, error) { - u, err := url.Parse(t.baseURL) - if err != nil { - return "", err - } - port := u.Port() - if port == "" { - switch serviceName { - case DeployService: - port = "19071" - case QueryService, DocumentService: - port = "8080" - default: - return "", fmt.Errorf("unknown service: %s", serviceName) - } - u.Host = u.Host + ":" + port - } - return u.String(), nil -} +type requestFunc func() *http.Request -func (t *customTarget) waitForConvergence(timeout time.Duration) error { - deployer, err := t.Service(DeployService, 0, 0, "") - if err != nil { - return err - } - url := fmt.Sprintf("%s/application/v2/tenant/default/application/default/environment/prod/region/default/instance/default/serviceconverge", deployer.BaseURL) +// waitForOK queries url and returns its status code. If the url returns a non-200 status code, it is repeatedly queried +// until timeout elapses. +func waitForOK(url string, certificate *tls.Certificate, timeout time.Duration) (int, error) { req, err := http.NewRequest("GET", url, nil) if err != nil { - return err - } - converged := false - convergedFunc := func(status int, response []byte) (bool, error) { - if status/100 != 2 { - return false, nil - } - var resp serviceConvergeResponse - if err := json.Unmarshal(response, &resp); err != nil { - return false, nil - } - converged = resp.Converged - return converged, nil - } - if _, err := wait(convergedFunc, func() *http.Request { return req }, nil, timeout); err != nil { - return err - } - if !converged { - return fmt.Errorf("services have not converged") - } - return nil -} - -type cloudTarget struct { - apiOptions APIOptions - deploymentOptions CloudDeploymentOptions - logOptions LogOptions - ztsClient ztsClient -} - -type ztsClient interface { - AccessToken(domain string, certficiate tls.Certificate) (string, error) -} - -func (t *cloudTarget) resolveEndpoint(cluster string) (string, error) { - if cluster == "" { - for _, u := range t.deploymentOptions.ClusterURLs { - if len(t.deploymentOptions.ClusterURLs) == 1 { - return u, nil - } else { - return "", fmt.Errorf("multiple clusters, none chosen: %v", t.deploymentOptions.ClusterURLs) - } - } - } else { - u := t.deploymentOptions.ClusterURLs[cluster] - if u == "" { - clusters := make([]string, len(t.deploymentOptions.ClusterURLs)) - for c := range t.deploymentOptions.ClusterURLs { - clusters = append(clusters, c) - } - return "", fmt.Errorf("unknown cluster '%s': must be one of %v", cluster, clusters) - } - return u, nil - } - - return "", fmt.Errorf("no endpoints") -} - -func (t *cloudTarget) Type() string { - switch t.apiOptions.System.Name { - case MainSystem.Name, CDSystem.Name: - return TargetHosted - } - return TargetCloud -} - -func (t *cloudTarget) Deployment() Deployment { return t.deploymentOptions.Deployment } - -func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, cluster string) (*Service, error) { - if name != DeployService && t.deploymentOptions.ClusterURLs == nil { - if err := t.waitForEndpoints(timeout, runID); err != nil { - return nil, err - } - } - switch name { - case DeployService: - return &Service{Name: name, BaseURL: t.apiOptions.System.URL, TLSOptions: t.apiOptions.TLSOptions, ztsClient: t.ztsClient}, nil - case QueryService, DocumentService: - url, err := t.resolveEndpoint(cluster) - if err != nil { - return nil, err - } - t.deploymentOptions.TLSOptions.AthenzDomain = t.apiOptions.System.AthenzDomain - return &Service{Name: name, BaseURL: url, TLSOptions: t.deploymentOptions.TLSOptions, ztsClient: t.ztsClient}, nil - } - return nil, fmt.Errorf("unknown service: %s", name) -} - -func (t *cloudTarget) SignRequest(req *http.Request, keyID string) error { - if t.apiOptions.System.IsPublic() { - if t.apiOptions.APIKey != nil { - signer := NewRequestSigner(keyID, t.apiOptions.APIKey) - return signer.SignRequest(req) - } else { - return t.addAuth0AccessToken(req) - } - } else { - if t.apiOptions.TLSOptions.KeyPair.Certificate == nil { - return fmt.Errorf("system %s requires a certificate for authentication", t.apiOptions.System.Name) - } - return nil - } -} - -func (t *cloudTarget) CheckVersion(clientVersion version.Version) error { - if clientVersion.IsZero() { // development version is always fine - return nil - } - req, err := http.NewRequest("GET", fmt.Sprintf("%s/cli/v1/", t.apiOptions.System.URL), nil) - if err != nil { - return err - } - response, err := util.HttpDo(req, 10*time.Second, "") - if err != nil { - return err - } - defer response.Body.Close() - var cliResponse struct { - MinVersion string `json:"minVersion"` - } - dec := json.NewDecoder(response.Body) - if err := dec.Decode(&cliResponse); err != nil { - return err - } - minVersion, err := version.Parse(cliResponse.MinVersion) - if err != nil { - return err - } - if clientVersion.Less(minVersion) { - return fmt.Errorf("client version %s is less than the minimum supported version: %s", clientVersion, minVersion) - } - return nil -} - -func (t *cloudTarget) addAuth0AccessToken(request *http.Request) error { - a, err := auth0.GetAuth0(t.apiOptions.AuthConfigPath, t.apiOptions.System.Name, t.apiOptions.System.URL) - if err != nil { - return err - } - system, err := a.PrepareSystem(auth0.ContextWithCancel()) - if err != nil { - return err - } - request.Header.Set("Authorization", "Bearer "+system.AccessToken) - return nil -} - -func (t *cloudTarget) logsURL() string { - return fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s/logs", - t.apiOptions.System.URL, - t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance, - t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region) -} - -func (t *cloudTarget) PrintLog(options LogOptions) error { - req, err := http.NewRequest("GET", t.logsURL(), nil) - if err != nil { - return err - } - lastFrom := options.From - requestFunc := func() *http.Request { - fromMillis := lastFrom.Unix() * 1000 - q := req.URL.Query() - q.Set("from", strconv.FormatInt(fromMillis, 10)) - if !options.To.IsZero() { - toMillis := options.To.Unix() * 1000 - q.Set("to", strconv.FormatInt(toMillis, 10)) - } - req.URL.RawQuery = q.Encode() - t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()) - return req - } - logFunc := func(status int, response []byte) (bool, error) { - if ok, err := isOK(status); !ok { - return ok, err - } - logEntries, err := ReadLogEntries(bytes.NewReader(response)) - if err != nil { - return true, err - } - for _, le := range logEntries { - if !le.Time.After(lastFrom) { - continue - } - if LogLevel(le.Level) > options.Level { - continue - } - fmt.Fprintln(options.Writer, le.Format(options.Dequote)) - } - if len(logEntries) > 0 { - lastFrom = logEntries[len(logEntries)-1].Time - } - return false, nil - } - var timeout time.Duration - if options.Follow { - timeout = math.MaxInt64 // No timeout - } - _, err = wait(logFunc, requestFunc, &t.apiOptions.TLSOptions.KeyPair, timeout) - return err -} - -func (t *cloudTarget) waitForEndpoints(timeout time.Duration, runID int64) error { - if runID > 0 { - if err := t.waitForRun(runID, timeout); err != nil { - return err - } - } - return t.discoverEndpoints(timeout) -} - -func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error { - runURL := fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/job/%s-%s/run/%d", - t.apiOptions.System.URL, - t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance, - t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region, runID) - req, err := http.NewRequest("GET", runURL, nil) - if err != nil { - return err - } - lastID := int64(-1) - requestFunc := func() *http.Request { - q := req.URL.Query() - q.Set("after", strconv.FormatInt(lastID, 10)) - req.URL.RawQuery = q.Encode() - if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil { - panic(err) - } - return req - } - jobSuccessFunc := func(status int, response []byte) (bool, error) { - if ok, err := isOK(status); !ok { - return ok, err - } - var resp jobResponse - if err := json.Unmarshal(response, &resp); err != nil { - return false, nil - } - if t.logOptions.Writer != nil { - lastID = t.printLog(resp, lastID) - } - if resp.Active { - return false, nil - } - if resp.Status != "success" { - return false, fmt.Errorf("run %d ended with unsuccessful status: %s", runID, resp.Status) - } - return true, nil - } - _, err = wait(jobSuccessFunc, requestFunc, &t.apiOptions.TLSOptions.KeyPair, timeout) - return err -} - -func (t *cloudTarget) printLog(response jobResponse, last int64) int64 { - if response.LastID == 0 { - return last - } - var msgs []logMessage - for step, stepMsgs := range response.Log { - for _, msg := range stepMsgs { - if step == "copyVespaLogs" && LogLevel(msg.Type) > t.logOptions.Level || LogLevel(msg.Type) == 3 { - continue - } - msgs = append(msgs, msg) - } - } - sort.Slice(msgs, func(i, j int) bool { return msgs[i].At < msgs[j].At }) - for _, msg := range msgs { - tm := time.Unix(msg.At/1000, (msg.At%1000)*1000) - fmtTime := tm.Format("15:04:05") - fmt.Fprintf(t.logOptions.Writer, "[%s] %-7s %s\n", fmtTime, msg.Type, msg.Message) - } - return response.LastID -} - -func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error { - deploymentURL := fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s", - t.apiOptions.System.URL, - t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance, - t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region) - req, err := http.NewRequest("GET", deploymentURL, nil) - if err != nil { - return err - } - if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil { - return err - } - urlsByCluster := make(map[string]string) - endpointFunc := func(status int, response []byte) (bool, error) { - if ok, err := isOK(status); !ok { - return ok, err - } - var resp deploymentResponse - if err := json.Unmarshal(response, &resp); err != nil { - return false, nil - } - if len(resp.Endpoints) == 0 { - return false, nil - } - for _, endpoint := range resp.Endpoints { - if endpoint.Scope != "zone" { - continue - } - urlsByCluster[endpoint.Cluster] = endpoint.URL - } - return true, nil - } - if _, err = wait(endpointFunc, func() *http.Request { return req }, &t.apiOptions.TLSOptions.KeyPair, timeout); err != nil { - return err - } - if len(urlsByCluster) == 0 { - return fmt.Errorf("no endpoints discovered") - } - t.deploymentOptions.ClusterURLs = urlsByCluster - return nil -} - -func isOK(status int) (bool, error) { - if status == 401 { - return false, fmt.Errorf("status %d: invalid api key", status) - } - return status/100 == 2, nil -} - -// LocalTarget creates a target for a Vespa platform running locally. -func LocalTarget() Target { - return &customTarget{targetType: TargetLocal, baseURL: "http://127.0.0.1"} -} - -// CustomTarget creates a Target for a Vespa platform running at baseURL. -func CustomTarget(baseURL string) Target { - return &customTarget{targetType: TargetCustom, baseURL: baseURL} -} - -// CloudTarget creates a Target for the Vespa Cloud or hosted Vespa platform. -func CloudTarget(apiOptions APIOptions, deploymentOptions CloudDeploymentOptions, logOptions LogOptions) (Target, error) { - ztsClient, err := zts.NewClient(zts.DefaultURL, util.ActiveHttpClient) - if err != nil { - return nil, err + return 0, err } - return &cloudTarget{ - apiOptions: apiOptions, - deploymentOptions: deploymentOptions, - logOptions: logOptions, - ztsClient: ztsClient, - }, nil -} - -type deploymentEndpoint struct { - Cluster string `json:"cluster"` - URL string `json:"url"` - Scope string `json:"scope"` -} - -type deploymentResponse struct { - Endpoints []deploymentEndpoint `json:"endpoints"` + okFunc := func(status int, response []byte) (bool, error) { return isOK(status), nil } + return wait(okFunc, func() *http.Request { return req }, certificate, timeout) } -type serviceConvergeResponse struct { - Converged bool `json:"converged"` -} - -type jobResponse struct { - Active bool `json:"active"` - Status string `json:"status"` - Log map[string][]logMessage `json:"log"` - LastID int64 `json:"lastId"` -} - -type logMessage struct { - At int64 `json:"at"` - Type string `json:"type"` - Message string `json:"message"` -} - -type responseFunc func(status int, response []byte) (bool, error) - -type requestFunc func() *http.Request - func wait(fn responseFunc, reqFn requestFunc, certificate *tls.Certificate, timeout time.Duration) (int, error) { if certificate != nil { util.ActiveHttpClient.UseCertificate([]tls.Certificate{*certificate}) diff --git a/client/go/vespa/target_cloud.go b/client/go/vespa/target_cloud.go new file mode 100644 index 00000000000..f4eccaacab6 --- /dev/null +++ b/client/go/vespa/target_cloud.go @@ -0,0 +1,382 @@ +package vespa + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "math" + "net/http" + "sort" + "strconv" + "time" + + "github.com/vespa-engine/vespa/client/go/auth0" + "github.com/vespa-engine/vespa/client/go/util" + "github.com/vespa-engine/vespa/client/go/version" + "github.com/vespa-engine/vespa/client/go/zts" +) + +// CloudOptions configures URL and authentication for a cloud target. +type APIOptions struct { + System System + TLSOptions TLSOptions + APIKey []byte + AuthConfigPath string +} + +// CloudDeploymentOptions configures the deployment to manage through a cloud target. +type CloudDeploymentOptions struct { + Deployment Deployment + TLSOptions TLSOptions + ClusterURLs map[string]string // Endpoints keyed on cluster name +} + +type cloudTarget struct { + apiOptions APIOptions + deploymentOptions CloudDeploymentOptions + logOptions LogOptions + ztsClient ztsClient +} + +type deploymentEndpoint struct { + Cluster string `json:"cluster"` + URL string `json:"url"` + Scope string `json:"scope"` +} + +type deploymentResponse struct { + Endpoints []deploymentEndpoint `json:"endpoints"` +} + +type jobResponse struct { + Active bool `json:"active"` + Status string `json:"status"` + Log map[string][]logMessage `json:"log"` + LastID int64 `json:"lastId"` +} + +type logMessage struct { + At int64 `json:"at"` + Type string `json:"type"` + Message string `json:"message"` +} + +type ztsClient interface { + AccessToken(domain string, certficiate tls.Certificate) (string, error) +} + +// CloudTarget creates a Target for the Vespa Cloud or hosted Vespa platform. +func CloudTarget(apiOptions APIOptions, deploymentOptions CloudDeploymentOptions, logOptions LogOptions) (Target, error) { + ztsClient, err := zts.NewClient(zts.DefaultURL, util.ActiveHttpClient) + if err != nil { + return nil, err + } + return &cloudTarget{ + apiOptions: apiOptions, + deploymentOptions: deploymentOptions, + logOptions: logOptions, + ztsClient: ztsClient, + }, nil +} + +func (t *cloudTarget) resolveEndpoint(cluster string) (string, error) { + if cluster == "" { + for _, u := range t.deploymentOptions.ClusterURLs { + if len(t.deploymentOptions.ClusterURLs) == 1 { + return u, nil + } else { + return "", fmt.Errorf("multiple clusters, none chosen: %v", t.deploymentOptions.ClusterURLs) + } + } + } else { + u := t.deploymentOptions.ClusterURLs[cluster] + if u == "" { + clusters := make([]string, len(t.deploymentOptions.ClusterURLs)) + for c := range t.deploymentOptions.ClusterURLs { + clusters = append(clusters, c) + } + return "", fmt.Errorf("unknown cluster '%s': must be one of %v", cluster, clusters) + } + return u, nil + } + + return "", fmt.Errorf("no endpoints") +} + +func (t *cloudTarget) Type() string { + switch t.apiOptions.System.Name { + case MainSystem.Name, CDSystem.Name: + return TargetHosted + } + return TargetCloud +} + +func (t *cloudTarget) Deployment() Deployment { return t.deploymentOptions.Deployment } + +func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, cluster string) (*Service, error) { + switch name { + case DeployService: + service := &Service{Name: name, BaseURL: t.apiOptions.System.URL, TLSOptions: t.apiOptions.TLSOptions, ztsClient: t.ztsClient} + if timeout > 0 { + status, err := service.Wait(timeout) + if err != nil { + return nil, err + } + if !isOK(status) { + return nil, fmt.Errorf("got status %d from deploy service at %s", status, service.BaseURL) + } + } + return service, nil + case QueryService, DocumentService: + if t.deploymentOptions.ClusterURLs == nil { + if err := t.waitForEndpoints(timeout, runID); err != nil { + return nil, err + } + } + url, err := t.resolveEndpoint(cluster) + if err != nil { + return nil, err + } + t.deploymentOptions.TLSOptions.AthenzDomain = t.apiOptions.System.AthenzDomain + return &Service{Name: name, BaseURL: url, TLSOptions: t.deploymentOptions.TLSOptions, ztsClient: t.ztsClient}, nil + } + return nil, fmt.Errorf("unknown service: %s", name) +} + +func (t *cloudTarget) SignRequest(req *http.Request, keyID string) error { + if t.apiOptions.System.IsPublic() { + if t.apiOptions.APIKey != nil { + signer := NewRequestSigner(keyID, t.apiOptions.APIKey) + return signer.SignRequest(req) + } else { + return t.addAuth0AccessToken(req) + } + } else { + if t.apiOptions.TLSOptions.KeyPair.Certificate == nil { + return fmt.Errorf("system %s requires a certificate for authentication", t.apiOptions.System.Name) + } + return nil + } +} + +func (t *cloudTarget) CheckVersion(clientVersion version.Version) error { + if clientVersion.IsZero() { // development version is always fine + return nil + } + req, err := http.NewRequest("GET", fmt.Sprintf("%s/cli/v1/", t.apiOptions.System.URL), nil) + if err != nil { + return err + } + response, err := util.HttpDo(req, 10*time.Second, "") + if err != nil { + return err + } + defer response.Body.Close() + var cliResponse struct { + MinVersion string `json:"minVersion"` + } + dec := json.NewDecoder(response.Body) + if err := dec.Decode(&cliResponse); err != nil { + return err + } + minVersion, err := version.Parse(cliResponse.MinVersion) + if err != nil { + return err + } + if clientVersion.Less(minVersion) { + return fmt.Errorf("client version %s is less than the minimum supported version: %s", clientVersion, minVersion) + } + return nil +} + +func (t *cloudTarget) addAuth0AccessToken(request *http.Request) error { + a, err := auth0.GetAuth0(t.apiOptions.AuthConfigPath, t.apiOptions.System.Name, t.apiOptions.System.URL) + if err != nil { + return err + } + system, err := a.PrepareSystem(auth0.ContextWithCancel()) + if err != nil { + return err + } + request.Header.Set("Authorization", "Bearer "+system.AccessToken) + return nil +} + +func (t *cloudTarget) logsURL() string { + return fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s/logs", + t.apiOptions.System.URL, + t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance, + t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region) +} + +func (t *cloudTarget) PrintLog(options LogOptions) error { + req, err := http.NewRequest("GET", t.logsURL(), nil) + if err != nil { + return err + } + lastFrom := options.From + requestFunc := func() *http.Request { + fromMillis := lastFrom.Unix() * 1000 + q := req.URL.Query() + q.Set("from", strconv.FormatInt(fromMillis, 10)) + if !options.To.IsZero() { + toMillis := options.To.Unix() * 1000 + q.Set("to", strconv.FormatInt(toMillis, 10)) + } + req.URL.RawQuery = q.Encode() + t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()) + return req + } + logFunc := func(status int, response []byte) (bool, error) { + if ok, err := isCloudOK(status); !ok { + return ok, err + } + logEntries, err := ReadLogEntries(bytes.NewReader(response)) + if err != nil { + return true, err + } + for _, le := range logEntries { + if !le.Time.After(lastFrom) { + continue + } + if LogLevel(le.Level) > options.Level { + continue + } + fmt.Fprintln(options.Writer, le.Format(options.Dequote)) + } + if len(logEntries) > 0 { + lastFrom = logEntries[len(logEntries)-1].Time + } + return false, nil + } + var timeout time.Duration + if options.Follow { + timeout = math.MaxInt64 // No timeout + } + _, err = wait(logFunc, requestFunc, &t.apiOptions.TLSOptions.KeyPair, timeout) + return err +} + +func (t *cloudTarget) waitForEndpoints(timeout time.Duration, runID int64) error { + if runID > 0 { + if err := t.waitForRun(runID, timeout); err != nil { + return err + } + } + return t.discoverEndpoints(timeout) +} + +func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error { + runURL := fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/job/%s-%s/run/%d", + t.apiOptions.System.URL, + t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance, + t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region, runID) + req, err := http.NewRequest("GET", runURL, nil) + if err != nil { + return err + } + lastID := int64(-1) + requestFunc := func() *http.Request { + q := req.URL.Query() + q.Set("after", strconv.FormatInt(lastID, 10)) + req.URL.RawQuery = q.Encode() + if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil { + panic(err) + } + return req + } + jobSuccessFunc := func(status int, response []byte) (bool, error) { + if ok, err := isCloudOK(status); !ok { + return ok, err + } + var resp jobResponse + if err := json.Unmarshal(response, &resp); err != nil { + return false, nil + } + if t.logOptions.Writer != nil { + lastID = t.printLog(resp, lastID) + } + if resp.Active { + return false, nil + } + if resp.Status != "success" { + return false, fmt.Errorf("run %d ended with unsuccessful status: %s", runID, resp.Status) + } + return true, nil + } + _, err = wait(jobSuccessFunc, requestFunc, &t.apiOptions.TLSOptions.KeyPair, timeout) + return err +} + +func (t *cloudTarget) printLog(response jobResponse, last int64) int64 { + if response.LastID == 0 { + return last + } + var msgs []logMessage + for step, stepMsgs := range response.Log { + for _, msg := range stepMsgs { + if step == "copyVespaLogs" && LogLevel(msg.Type) > t.logOptions.Level || LogLevel(msg.Type) == 3 { + continue + } + msgs = append(msgs, msg) + } + } + sort.Slice(msgs, func(i, j int) bool { return msgs[i].At < msgs[j].At }) + for _, msg := range msgs { + tm := time.Unix(msg.At/1000, (msg.At%1000)*1000) + fmtTime := tm.Format("15:04:05") + fmt.Fprintf(t.logOptions.Writer, "[%s] %-7s %s\n", fmtTime, msg.Type, msg.Message) + } + return response.LastID +} + +func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error { + deploymentURL := fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s", + t.apiOptions.System.URL, + t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance, + t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region) + req, err := http.NewRequest("GET", deploymentURL, nil) + if err != nil { + return err + } + if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil { + return err + } + urlsByCluster := make(map[string]string) + endpointFunc := func(status int, response []byte) (bool, error) { + if ok, err := isCloudOK(status); !ok { + return ok, err + } + var resp deploymentResponse + if err := json.Unmarshal(response, &resp); err != nil { + return false, nil + } + if len(resp.Endpoints) == 0 { + return false, nil + } + for _, endpoint := range resp.Endpoints { + if endpoint.Scope != "zone" { + continue + } + urlsByCluster[endpoint.Cluster] = endpoint.URL + } + return true, nil + } + if _, err = wait(endpointFunc, func() *http.Request { return req }, &t.apiOptions.TLSOptions.KeyPair, timeout); err != nil { + return err + } + if len(urlsByCluster) == 0 { + return fmt.Errorf("no endpoints discovered") + } + t.deploymentOptions.ClusterURLs = urlsByCluster + return nil +} + +func isCloudOK(status int) (bool, error) { + if status == 401 { + // when retrying we should give up immediately if we're not authorized + return false, fmt.Errorf("status %d: invalid credentials", status) + } + return isOK(status), nil +} diff --git a/client/go/vespa/target_custom.go b/client/go/vespa/target_custom.go new file mode 100644 index 00000000000..072ec8649e4 --- /dev/null +++ b/client/go/vespa/target_custom.go @@ -0,0 +1,128 @@ +package vespa + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/vespa-engine/vespa/client/go/version" +) + +type customTarget struct { + targetType string + baseURL string +} + +type serviceConvergeResponse struct { + Converged bool `json:"converged"` +} + +// LocalTarget creates a target for a Vespa platform running locally. +func LocalTarget() Target { + return &customTarget{targetType: TargetLocal, baseURL: "http://127.0.0.1"} +} + +// CustomTarget creates a Target for a Vespa platform running at baseURL. +func CustomTarget(baseURL string) Target { + return &customTarget{targetType: TargetCustom, baseURL: baseURL} +} + +func (t *customTarget) Type() string { return t.targetType } + +func (t *customTarget) Deployment() Deployment { return Deployment{} } + +func (t *customTarget) createService(name string) (*Service, error) { + switch name { + case DeployService, QueryService, DocumentService: + url, err := t.urlWithPort(name) + if err != nil { + return nil, err + } + return &Service{BaseURL: url, Name: name}, nil + } + return nil, fmt.Errorf("unknown service: %s", name) +} + +func (t *customTarget) Service(name string, timeout time.Duration, sessionOrRunID int64, cluster string) (*Service, error) { + service, err := t.createService(name) + if err != nil { + return nil, err + } + if timeout > 0 { + if name == DeployService { + status, err := service.Wait(timeout) + if err != nil { + return nil, err + } + if !isOK(status) { + return nil, fmt.Errorf("got status %d from deploy service at %s", status, service.BaseURL) + } + } else { + if err := t.waitForConvergence(timeout); err != nil { + return nil, err + } + } + } + return service, nil +} + +func (t *customTarget) PrintLog(options LogOptions) error { + return fmt.Errorf("reading logs from non-cloud deployment is unsupported") +} + +func (t *customTarget) SignRequest(req *http.Request, sigKeyId string) error { return nil } + +func (t *customTarget) CheckVersion(version version.Version) error { return nil } + +func (t *customTarget) urlWithPort(serviceName string) (string, error) { + u, err := url.Parse(t.baseURL) + if err != nil { + return "", err + } + port := u.Port() + if port == "" { + switch serviceName { + case DeployService: + port = "19071" + case QueryService, DocumentService: + port = "8080" + default: + return "", fmt.Errorf("unknown service: %s", serviceName) + } + u.Host = u.Host + ":" + port + } + return u.String(), nil +} + +func (t *customTarget) waitForConvergence(timeout time.Duration) error { + deployURL, err := t.urlWithPort(DeployService) + if err != nil { + return err + } + url := fmt.Sprintf("%s/application/v2/tenant/default/application/default/environment/prod/region/default/instance/default/serviceconverge", deployURL) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return err + } + converged := false + convergedFunc := func(status int, response []byte) (bool, error) { + if !isOK(status) { + return false, nil + } + var resp serviceConvergeResponse + if err := json.Unmarshal(response, &resp); err != nil { + return false, nil + } + converged = resp.Converged + return converged, nil + } + if _, err := wait(convergedFunc, func() *http.Request { return req }, nil, timeout); err != nil { + return err + } + if !converged { + return fmt.Errorf("services have not converged") + } + return nil +} diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java index e0e4318ccc8..7cb36374568 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java @@ -116,6 +116,7 @@ public interface ModelContext { @ModelFeatureFlag(owners = {"arnej"}) default boolean useQrserverServiceName() { return true; } @ModelFeatureFlag(owners = {"bjorncs", "baldersheim"}) default boolean enableJdiscPreshutdownCommand() { return true; } @ModelFeatureFlag(owners = {"arnej"}) default boolean avoidRenamingSummaryFeatures() { return false; } + @ModelFeatureFlag(owners = {"bjorncs", "baldersheim"}) default boolean mergeGroupingResultInSearchInvoker() { return false; } } /** Warning: As elsewhere in this package, do not make backwards incompatible changes that will break old config models! */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java index 22b752777e9..e483351a25a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java @@ -158,7 +158,7 @@ public class ContentSearchCluster extends AbstractConfigProducer<SearchCluster> String clusterName, ContentSearchCluster search) { List<ModelElement> indexedDefs = getIndexedSchemas(clusterElem); if (!indexedDefs.isEmpty()) { - IndexedSearchCluster isc = new IndexedSearchCluster(search, clusterName, 0); + IndexedSearchCluster isc = new IndexedSearchCluster(deployState, search, clusterName, 0); isc.setRoutingSelector(clusterElem.childAsString("documents.selection")); Double visibilityDelay = clusterElem.childAsDouble("engine.proton.visibility-delay"); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java index fb7c6696b54..53aac23135a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java @@ -6,7 +6,6 @@ import com.yahoo.config.model.producer.AbstractConfigProducer; import com.yahoo.prelude.fastsearch.DocumentdbInfoConfig; import com.yahoo.search.config.IndexInfoConfig; import com.yahoo.searchdefinition.DocumentOnlySchema; -import com.yahoo.searchdefinition.Schema; import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.vespa.config.search.DispatchConfig; @@ -34,6 +33,7 @@ public class IndexedSearchCluster extends SearchCluster IlscriptsConfig.Producer, DispatchConfig.Producer { + private final boolean mergeGroupingResultInSearchInvoker; private String indexingClusterName = null; // The name of the docproc cluster to run indexing, by config. private String indexingChainName = null; @@ -63,10 +63,11 @@ public class IndexedSearchCluster extends SearchCluster return routingSelector; } - public IndexedSearchCluster(AbstractConfigProducer<SearchCluster> parent, String clusterName, int index) { + public IndexedSearchCluster(DeployState deployState, AbstractConfigProducer<SearchCluster> parent, String clusterName, int index) { super(parent, clusterName, index); unionCfg = new UnionConfiguration(this, documentDbs); rootDispatch = new DispatchGroup(this); + mergeGroupingResultInSearchInvoker = deployState.featureFlags().mergeGroupingResultInSearchInvoker(); } @Override @@ -320,6 +321,7 @@ public class IndexedSearchCluster extends SearchCluster builder.maxWaitAfterCoverageFactor(searchCoverage.getMaxWaitAfterCoverageFactor()); } builder.warmuptime(5.0); + builder.mergeGroupingResultInSearchInvokerEnabled(mergeGroupingResultInSearchInvoker); } @Override diff --git a/configdefinitions/src/vespa/dispatch.def b/configdefinitions/src/vespa/dispatch.def index 17f42a73bfd..fef9300a410 100644 --- a/configdefinitions/src/vespa/dispatch.def +++ b/configdefinitions/src/vespa/dispatch.def @@ -71,3 +71,6 @@ node[].host string # The rpc port of this search node node[].port int + +# Temporary feature flag +mergeGroupingResultInSearchInvokerEnabled bool default=false diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java index 5a813c7886a..6222e7b1788 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java @@ -209,6 +209,7 @@ public class ModelContextImpl implements ModelContext { private final boolean inhibitDefaultMergesWhenGlobalMergesPending; private final boolean useQrserverServiceName; private final boolean avoidRenamingSummaryFeatures; + private final boolean mergeGroupingResultInSearchInvoker; public FeatureFlags(FlagSource source, ApplicationId appId) { this.defaultTermwiseLimit = flagValue(source, appId, Flags.DEFAULT_TERM_WISE_LIMIT); @@ -256,6 +257,7 @@ public class ModelContextImpl implements ModelContext { this.inhibitDefaultMergesWhenGlobalMergesPending = flagValue(source, appId, Flags.INHIBIT_DEFAULT_MERGES_WHEN_GLOBAL_MERGES_PENDING); this.useQrserverServiceName = flagValue(source, appId, Flags.USE_QRSERVER_SERVICE_NAME); this.avoidRenamingSummaryFeatures = flagValue(source, appId, Flags.AVOID_RENAMING_SUMMARY_FEATURES); + this.mergeGroupingResultInSearchInvoker = flagValue(source, appId, Flags.MERGE_GROUPING_RESULT_IN_SEARCH_INVOKER); } @Override public double defaultTermwiseLimit() { return defaultTermwiseLimit; } @@ -305,6 +307,7 @@ public class ModelContextImpl implements ModelContext { @Override public boolean inhibitDefaultMergesWhenGlobalMergesPending() { return inhibitDefaultMergesWhenGlobalMergesPending; } @Override public boolean useQrserverServiceName() { return useQrserverServiceName; } @Override public boolean avoidRenamingSummaryFeatures() { return avoidRenamingSummaryFeatures; } + @Override public boolean mergeGroupingResultInSearchInvoker() { return mergeGroupingResultInSearchInvoker; } private static <V> V flagValue(FlagSource source, ApplicationId appId, UnboundFlag<? extends V, ?, ?> flag) { return flag.bindTo(source) diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/GroupingResultAggregator.java b/container-search/src/main/java/com/yahoo/search/dispatch/GroupingResultAggregator.java new file mode 100644 index 00000000000..5ce7accfdd4 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/dispatch/GroupingResultAggregator.java @@ -0,0 +1,50 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.dispatch; + +import com.yahoo.prelude.fastsearch.DocsumDefinitionSet; +import com.yahoo.prelude.fastsearch.GroupingListHit; +import com.yahoo.searchlib.aggregation.Grouping; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Incrementally merges underlying {@link Grouping} instances from {@link GroupingListHit} hits. + * + * @author bjorncs + */ +class GroupingResultAggregator { + private static final Logger log = Logger.getLogger(GroupingResultAggregator.class.getName()); + + private final Map<Integer, Grouping> groupings = new LinkedHashMap<>(); + private DocsumDefinitionSet documentDefinitions = null; + private int groupingHitsMerged = 0; + + void mergeWith(GroupingListHit result) { + if (groupingHitsMerged == 0) documentDefinitions = result.getDocsumDefinitionSet(); + ++groupingHitsMerged; + log.log(Level.FINE, () -> + String.format("Merging hit #%d having %d groupings", + groupingHitsMerged, result.getGroupingList().size())); + for (Grouping grouping : result.getGroupingList()) { + groupings.merge(grouping.getId(), grouping, (existingGrouping, newGrouping) -> { + existingGrouping.merge(newGrouping); + return existingGrouping; + }); + } + } + + Optional<GroupingListHit> toAggregatedHit() { + if (groupingHitsMerged == 0) return Optional.empty(); + log.log(Level.FINE, () -> + String.format("Creating aggregated hit containing %d groupings from %d hits", + groupings.size(), groupingHitsMerged)); + groupings.values().forEach(Grouping::postMerge); + return Optional.of(new GroupingListHit(List.copyOf(groupings.values()), documentDefinitions)); + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java index 4e658122cdf..d7c9f1dce53 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.dispatch; +import com.yahoo.prelude.fastsearch.GroupingListHit; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.dispatch.searchcluster.Group; @@ -44,6 +45,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM private final Group group; private final LinkedBlockingQueue<SearchInvoker> availableForProcessing; private final Set<Integer> alreadyFailedNodes; + private final boolean mergeGroupingResult; private Query query; private boolean adaptiveTimeoutCalculated = false; @@ -71,6 +73,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM this.group = group; this.availableForProcessing = newQueue(); this.alreadyFailedNodes = alreadyFailedNodes; + this.mergeGroupingResult = searchCluster.dispatchConfig().mergeGroupingResultInSearchInvokerEnabled(); } /** @@ -115,6 +118,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM long nextTimeout = query.getTimeLeft(); boolean extraDebug = (query.getOffset() == 0) && (query.getHits() == 7) && log.isLoggable(java.util.logging.Level.FINE); List<InvokerResult> processed = new ArrayList<>(); + var groupingResultAggregator = new GroupingResultAggregator(); try { while (!invokers.isEmpty() && nextTimeout >= 0) { SearchInvoker invoker = availableForProcessing.poll(nextTimeout, TimeUnit.MILLISECONDS); @@ -126,7 +130,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM if (extraDebug) { processed.add(toMerge); } - merged = mergeResult(result.getResult(), toMerge, merged); + merged = mergeResult(result.getResult(), toMerge, merged, groupingResultAggregator); ejectInvoker(invoker); } nextTimeout = nextTimeout(); @@ -134,6 +138,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM } catch (InterruptedException e) { throw new RuntimeException("Interrupted while waiting for search results", e); } + groupingResultAggregator.toAggregatedHit().ifPresent(h -> result.getResult().hits().add(h)); insertNetworkErrors(result.getResult()); result.getResult().setCoverage(createCoverage()); @@ -238,14 +243,20 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM return nextAdaptive; } - private List<LeanHit> mergeResult(Result result, InvokerResult partialResult, List<LeanHit> current) { + private List<LeanHit> mergeResult(Result result, InvokerResult partialResult, List<LeanHit> current, + GroupingResultAggregator groupingResultAggregator) { collectCoverage(partialResult.getResult().getCoverage(true)); result.mergeWith(partialResult.getResult()); List<Hit> partialNonLean = partialResult.getResult().hits().asUnorderedHits(); for(Hit hit : partialNonLean) { if (hit.isAuxiliary()) { - result.hits().add(hit); + if (hit instanceof GroupingListHit && mergeGroupingResult) { + var groupingHit = (GroupingListHit) hit; + groupingResultAggregator.mergeWith(groupingHit); + } else { + result.hits().add(hit); + } } } if (current.isEmpty() ) { diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java index 6e08b1c6fa5..347276d680d 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java @@ -4,6 +4,7 @@ package com.yahoo.search.dispatch; import com.yahoo.document.GlobalId; import com.yahoo.document.idstring.IdString; import com.yahoo.prelude.fastsearch.FastHit; +import com.yahoo.prelude.fastsearch.GroupingListHit; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.dispatch.searchcluster.Group; @@ -13,6 +14,11 @@ import com.yahoo.search.result.Coverage; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.result.Hit; import com.yahoo.search.result.Relevance; +import com.yahoo.searchlib.aggregation.Grouping; +import com.yahoo.searchlib.aggregation.MaxAggregationResult; +import com.yahoo.searchlib.aggregation.MinAggregationResult; +import com.yahoo.searchlib.expression.IntegerResultNode; +import com.yahoo.searchlib.expression.StringResultNode; import com.yahoo.test.ManualClock; import org.junit.Test; @@ -320,6 +326,39 @@ public class InterleavedSearchInvokerTest { assertEquals(3, result.getQuery().getHits()); } + @Test + public void requireThatGroupingsAreMerged() throws IOException { + SearchCluster cluster = new MockSearchCluster("!", 1, 2); + List<SearchInvoker> invokers = new ArrayList<>(); + + Grouping grouping1 = new Grouping(0); + grouping1.setRoot(new com.yahoo.searchlib.aggregation.Group() + .addChild(new com.yahoo.searchlib.aggregation.Group() + .setId(new StringResultNode("uniqueA")) + .addAggregationResult(new MaxAggregationResult().setMax(new IntegerResultNode(6)).setTag(4))) + .addChild(new com.yahoo.searchlib.aggregation.Group() + .setId(new StringResultNode("common")) + .addAggregationResult(new MaxAggregationResult().setMax(new IntegerResultNode(9)).setTag(4)))); + invokers.add(new MockInvoker(0).setHits(List.of(new GroupingListHit(List.of(grouping1))))); + + Grouping grouping2 = new Grouping(0); + grouping2.setRoot(new com.yahoo.searchlib.aggregation.Group() + .addChild(new com.yahoo.searchlib.aggregation.Group() + .setId(new StringResultNode("uniqueB")) + .addAggregationResult(new MaxAggregationResult().setMax(new IntegerResultNode(9)).setTag(4))) + .addChild(new com.yahoo.searchlib.aggregation.Group() + .setId(new StringResultNode("common")) + .addAggregationResult(new MinAggregationResult().setMin(new IntegerResultNode(6)).setTag(3)))); + invokers.add(new MockInvoker(0).setHits(List.of(new GroupingListHit(List.of(grouping2))))); + + InterleavedSearchInvoker invoker = new InterleavedSearchInvoker(invokers, cluster, new Group(0, List.of()), Collections.emptySet()); + invoker.responseAvailable(invokers.get(0)); + invoker.responseAvailable(invokers.get(1)); + Result result = invoker.search(query, null); + assertEquals(1, ((GroupingListHit) result.hits().get(0)).getGroupingList().size()); + + } + private static InterleavedSearchInvoker createInterLeavedTestInvoker(List<Double> a, List<Double> b, Group group) { SearchCluster cluster = new MockSearchCluster("!", 1, 2); List<SearchInvoker> invokers = new ArrayList<>(); diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java b/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java index 2acce0f8d2d..54c8c1e0522 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java @@ -120,7 +120,8 @@ public class MockSearchCluster extends SearchCluster { builder.minActivedocsPercentage(88.0); builder.minGroupCoverage(99.0); builder.minSearchCoverage(minSearchCoverage); - builder.distributionPolicy(DispatchConfig.DistributionPolicy.Enum.ROUNDROBIN); + builder.distributionPolicy(DispatchConfig.DistributionPolicy.Enum.ROUNDROBIN) + .mergeGroupingResultInSearchInvokerEnabled(true); if (minSearchCoverage < 100.0) { builder.minWaitAfterCoverageFactor(0); builder.maxWaitAfterCoverageFactor(0.5); diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index 881b62d1e04..5376fa983af 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -390,6 +390,13 @@ public class Flags { "Takes effect immediately", TENANT_ID); + public static final UnboundBooleanFlag MERGE_GROUPING_RESULT_IN_SEARCH_INVOKER = defineFeatureFlag( + "merge-grouping-result-in-search-invoker", false, + List.of("bjorncs", "baldersheim"), "2022-02-23", "2022-08-01", + "Merge grouping results incrementally in interleaved search invoker", + "Takes effect at redeployment", + APPLICATION_ID); + /** WARNING: public for testing: All flags should be defined in {@link Flags}. */ public static UnboundBooleanFlag defineFeatureFlag(String flagId, boolean defaultValue, List<String> owners, String createdAt, String expiresAt, String description, diff --git a/searchlib/src/main/java/com/yahoo/searchlib/aggregation/Grouping.java b/searchlib/src/main/java/com/yahoo/searchlib/aggregation/Grouping.java index 25b3cb18ff9..c88a567c559 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/aggregation/Grouping.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/aggregation/Grouping.java @@ -46,6 +46,8 @@ public class Grouping extends Identifiable { // Actual root group, does not require level details. private Group root = new Group(); + private boolean postMergeCompleted = false; + /** * <p>Constructs an empty result node. <b>NOTE:</b> This instance is broken until non-optional member data is * set.</p> @@ -78,7 +80,9 @@ public class Grouping extends Identifiable { * that might have changes due to the merge.</p> */ public void postMerge() { + if (postMergeCompleted) return; root.postMerge(groupingLevels, firstLevel, 0); + postMergeCompleted = true; } /** diff --git a/vespalib/src/tests/util/generationhandler/CMakeLists.txt b/vespalib/src/tests/util/generationhandler/CMakeLists.txt index 677d5caa0e6..fdf54c59854 100644 --- a/vespalib/src/tests/util/generationhandler/CMakeLists.txt +++ b/vespalib/src/tests/util/generationhandler/CMakeLists.txt @@ -4,5 +4,6 @@ vespa_add_executable(vespalib_generationhandler_test_app TEST generationhandler_test.cpp DEPENDS vespalib + GTest::GTest ) vespa_add_test(NAME vespalib_generationhandler_test_app COMMAND vespalib_generationhandler_test_app) diff --git a/vespalib/src/tests/util/generationhandler/generationhandler_test.cpp b/vespalib/src/tests/util/generationhandler/generationhandler_test.cpp index f269fe729fa..00da752a749 100644 --- a/vespalib/src/tests/util/generationhandler/generationhandler_test.cpp +++ b/vespalib/src/tests/util/generationhandler/generationhandler_test.cpp @@ -1,157 +1,137 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <vespa/vespalib/testkit/testapp.h> +#include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/util/generationhandler.h> #include <deque> namespace vespalib { -typedef GenerationHandler::Guard GenGuard; +using GenGuard = GenerationHandler::Guard; -class Test : public vespalib::TestApp { -private: - void requireThatGenerationCanBeIncreased(); - void requireThatReadersCanTakeGuards(); - void requireThatGuardsCanBeCopied(); - void requireThatTheFirstUsedGenerationIsCorrect(); - void requireThatGenerationCanGrowLarge(); -public: - int Main() override; +class GenerationHandlerTest : public ::testing::Test { +protected: + GenerationHandler gh; + GenerationHandlerTest(); + ~GenerationHandlerTest() override; }; -void -Test::requireThatGenerationCanBeIncreased() +GenerationHandlerTest::GenerationHandlerTest() + : ::testing::Test(), + gh() { - GenerationHandler gh; - EXPECT_EQUAL(0u, gh.getCurrentGeneration()); - EXPECT_EQUAL(0u, gh.getFirstUsedGeneration()); +} + +GenerationHandlerTest::~GenerationHandlerTest() = default; + +TEST_F(GenerationHandlerTest, require_that_generation_can_be_increased) +{ + EXPECT_EQ(0u, gh.getCurrentGeneration()); + EXPECT_EQ(0u, gh.getFirstUsedGeneration()); gh.incGeneration(); - EXPECT_EQUAL(1u, gh.getCurrentGeneration()); - EXPECT_EQUAL(1u, gh.getFirstUsedGeneration()); + EXPECT_EQ(1u, gh.getCurrentGeneration()); + EXPECT_EQ(1u, gh.getFirstUsedGeneration()); } -void -Test::requireThatReadersCanTakeGuards() +TEST_F(GenerationHandlerTest, require_that_readers_can_take_guards) { - GenerationHandler gh; - EXPECT_EQUAL(0u, gh.getGenerationRefCount(0)); + EXPECT_EQ(0u, gh.getGenerationRefCount(0)); { GenGuard g1 = gh.takeGuard(); - EXPECT_EQUAL(1u, gh.getGenerationRefCount(0)); + EXPECT_EQ(1u, gh.getGenerationRefCount(0)); { GenGuard g2 = gh.takeGuard(); - EXPECT_EQUAL(2u, gh.getGenerationRefCount(0)); + EXPECT_EQ(2u, gh.getGenerationRefCount(0)); gh.incGeneration(); { GenGuard g3 = gh.takeGuard(); - EXPECT_EQUAL(2u, gh.getGenerationRefCount(0)); - EXPECT_EQUAL(1u, gh.getGenerationRefCount(1)); - EXPECT_EQUAL(3u, gh.getGenerationRefCount()); + EXPECT_EQ(2u, gh.getGenerationRefCount(0)); + EXPECT_EQ(1u, gh.getGenerationRefCount(1)); + EXPECT_EQ(3u, gh.getGenerationRefCount()); } - EXPECT_EQUAL(2u, gh.getGenerationRefCount(0)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(1)); + EXPECT_EQ(2u, gh.getGenerationRefCount(0)); + EXPECT_EQ(0u, gh.getGenerationRefCount(1)); gh.incGeneration(); { GenGuard g3 = gh.takeGuard(); - EXPECT_EQUAL(2u, gh.getGenerationRefCount(0)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(1)); - EXPECT_EQUAL(1u, gh.getGenerationRefCount(2)); + EXPECT_EQ(2u, gh.getGenerationRefCount(0)); + EXPECT_EQ(0u, gh.getGenerationRefCount(1)); + EXPECT_EQ(1u, gh.getGenerationRefCount(2)); } - EXPECT_EQUAL(2u, gh.getGenerationRefCount(0)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(1)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(2)); + EXPECT_EQ(2u, gh.getGenerationRefCount(0)); + EXPECT_EQ(0u, gh.getGenerationRefCount(1)); + EXPECT_EQ(0u, gh.getGenerationRefCount(2)); } - EXPECT_EQUAL(1u, gh.getGenerationRefCount(0)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(1)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(2)); + EXPECT_EQ(1u, gh.getGenerationRefCount(0)); + EXPECT_EQ(0u, gh.getGenerationRefCount(1)); + EXPECT_EQ(0u, gh.getGenerationRefCount(2)); } - EXPECT_EQUAL(0u, gh.getGenerationRefCount(0)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(1)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(2)); + EXPECT_EQ(0u, gh.getGenerationRefCount(0)); + EXPECT_EQ(0u, gh.getGenerationRefCount(1)); + EXPECT_EQ(0u, gh.getGenerationRefCount(2)); } -void -Test::requireThatGuardsCanBeCopied() +TEST_F(GenerationHandlerTest, require_that_guards_can_be_copied) { - GenerationHandler gh; GenGuard g1 = gh.takeGuard(); - EXPECT_EQUAL(1u, gh.getGenerationRefCount(0)); + EXPECT_EQ(1u, gh.getGenerationRefCount(0)); GenGuard g2(g1); - EXPECT_EQUAL(2u, gh.getGenerationRefCount(0)); + EXPECT_EQ(2u, gh.getGenerationRefCount(0)); gh.incGeneration(); GenGuard g3 = gh.takeGuard(); - EXPECT_EQUAL(2u, gh.getGenerationRefCount(0)); - EXPECT_EQUAL(1u, gh.getGenerationRefCount(1)); + EXPECT_EQ(2u, gh.getGenerationRefCount(0)); + EXPECT_EQ(1u, gh.getGenerationRefCount(1)); g3 = g2; - EXPECT_EQUAL(3u, gh.getGenerationRefCount(0)); - EXPECT_EQUAL(0u, gh.getGenerationRefCount(1)); + EXPECT_EQ(3u, gh.getGenerationRefCount(0)); + EXPECT_EQ(0u, gh.getGenerationRefCount(1)); } -void -Test::requireThatTheFirstUsedGenerationIsCorrect() +TEST_F(GenerationHandlerTest, require_that_the_first_used_generation_is_correct) { - GenerationHandler gh; - EXPECT_EQUAL(0u, gh.getFirstUsedGeneration()); + EXPECT_EQ(0u, gh.getFirstUsedGeneration()); gh.incGeneration(); - EXPECT_EQUAL(1u, gh.getFirstUsedGeneration()); + EXPECT_EQ(1u, gh.getFirstUsedGeneration()); { GenGuard g1 = gh.takeGuard(); gh.incGeneration(); - EXPECT_EQUAL(1u, gh.getGenerationRefCount()); - EXPECT_EQUAL(1u, gh.getFirstUsedGeneration()); + EXPECT_EQ(1u, gh.getGenerationRefCount()); + EXPECT_EQ(1u, gh.getFirstUsedGeneration()); } - EXPECT_EQUAL(1u, gh.getFirstUsedGeneration()); + EXPECT_EQ(1u, gh.getFirstUsedGeneration()); gh.updateFirstUsedGeneration(); // Only writer should call this - EXPECT_EQUAL(0u, gh.getGenerationRefCount()); - EXPECT_EQUAL(2u, gh.getFirstUsedGeneration()); + EXPECT_EQ(0u, gh.getGenerationRefCount()); + EXPECT_EQ(2u, gh.getFirstUsedGeneration()); { GenGuard g1 = gh.takeGuard(); gh.incGeneration(); gh.incGeneration(); - EXPECT_EQUAL(1u, gh.getGenerationRefCount()); - EXPECT_EQUAL(2u, gh.getFirstUsedGeneration()); + EXPECT_EQ(1u, gh.getGenerationRefCount()); + EXPECT_EQ(2u, gh.getFirstUsedGeneration()); { GenGuard g2 = gh.takeGuard(); - EXPECT_EQUAL(2u, gh.getFirstUsedGeneration()); + EXPECT_EQ(2u, gh.getFirstUsedGeneration()); } } - EXPECT_EQUAL(2u, gh.getFirstUsedGeneration()); + EXPECT_EQ(2u, gh.getFirstUsedGeneration()); gh.updateFirstUsedGeneration(); // Only writer should call this - EXPECT_EQUAL(0u, gh.getGenerationRefCount()); - EXPECT_EQUAL(4u, gh.getFirstUsedGeneration()); + EXPECT_EQ(0u, gh.getGenerationRefCount()); + EXPECT_EQ(4u, gh.getFirstUsedGeneration()); } -void -Test::requireThatGenerationCanGrowLarge() +TEST_F(GenerationHandlerTest, require_that_generation_can_grow_large) { - GenerationHandler gh; std::deque<GenGuard> guards; for (size_t i = 0; i < 10000; ++i) { - EXPECT_EQUAL(i, gh.getCurrentGeneration()); + EXPECT_EQ(i, gh.getCurrentGeneration()); guards.push_back(gh.takeGuard()); // take guard on current generation if (i >= 128) { - EXPECT_EQUAL(i - 128, gh.getFirstUsedGeneration()); + EXPECT_EQ(i - 128, gh.getFirstUsedGeneration()); guards.pop_front(); - EXPECT_EQUAL(128u, gh.getGenerationRefCount()); + EXPECT_EQ(128u, gh.getGenerationRefCount()); } gh.incGeneration(); } } -int -Test::Main() -{ - TEST_INIT("generationhandler_test"); - - TEST_DO(requireThatGenerationCanBeIncreased()); - TEST_DO(requireThatReadersCanTakeGuards()); - TEST_DO(requireThatGuardsCanBeCopied()); - TEST_DO(requireThatTheFirstUsedGenerationIsCorrect()); - TEST_DO(requireThatGenerationCanGrowLarge()); - - TEST_DONE(); -} - } -TEST_APPHOOK(vespalib::Test); +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/tests/util/generationhandler_stress/CMakeLists.txt b/vespalib/src/tests/util/generationhandler_stress/CMakeLists.txt index 569db489e3c..7e5a5af79b5 100644 --- a/vespalib/src/tests/util/generationhandler_stress/CMakeLists.txt +++ b/vespalib/src/tests/util/generationhandler_stress/CMakeLists.txt @@ -4,5 +4,6 @@ vespa_add_executable(vespalib_generation_handler_stress_test_app generation_handler_stress_test.cpp DEPENDS vespalib + GTest::GTest ) -vespa_add_test(NAME vespalib_generation_handler_stress_test_app COMMAND vespalib_generation_handler_stress_test_app BENCHMARK) +vespa_add_test(NAME vespalib_generation_handler_stress_test_app COMMAND vespalib_generation_handler_stress_test_app --smoke-test) diff --git a/vespalib/src/tests/util/generationhandler_stress/generation_handler_stress_test.cpp b/vespalib/src/tests/util/generationhandler_stress/generation_handler_stress_test.cpp index fa2c525b518..0689909da09 100644 --- a/vespalib/src/tests/util/generationhandler_stress/generation_handler_stress_test.cpp +++ b/vespalib/src/tests/util/generationhandler_stress/generation_handler_stress_test.cpp @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/log/log.h> LOG_SETUP("generation_handler_stress_test"); -#include <vespa/vespalib/testkit/testapp.h> +#include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/util/generationhandler.h> #include <vespa/vespalib/util/threadstackexecutor.h> #include <vespa/vespalib/util/size_literals.h> @@ -10,6 +10,12 @@ using vespalib::Executor; using vespalib::GenerationHandler; using vespalib::ThreadStackExecutor; +namespace { + +bool smoke_test = false; +const vespalib::string smoke_test_option = "--smoke-test"; + +} struct WorkContext { @@ -21,26 +27,28 @@ struct WorkContext } }; -struct Fixture { +class Fixture : public ::testing::Test { +protected: GenerationHandler _generationHandler; uint32_t _readThreads; ThreadStackExecutor _writer; // 1 write thread - ThreadStackExecutor _readers; // multiple reader threads + std::unique_ptr<ThreadStackExecutor> _readers; // multiple reader threads std::atomic<long> _readSeed; std::atomic<long> _doneWriteWork; std::atomic<long> _doneReadWork; std::atomic<int> _stopRead; bool _reportWork; - Fixture(uint32_t readThreads = 1); - + Fixture(); ~Fixture(); - void readWork(const WorkContext &context); - void writeWork(uint32_t cnt, WorkContext &context); + void set_read_threads(uint32_t read_threads); + uint32_t getReadThreads() const { return _readThreads; } void stressTest(uint32_t writeCnt); - +public: + void readWork(const WorkContext &context); + void writeWork(uint32_t cnt, WorkContext &context); private: Fixture(const Fixture &index) = delete; Fixture(Fixture &&index) = delete; @@ -49,23 +57,27 @@ private: }; -Fixture::Fixture(uint32_t readThreads) - : _generationHandler(), - _readThreads(readThreads), +Fixture::Fixture() + : ::testing::Test(), + _generationHandler(), + _readThreads(1), _writer(1, 128_Ki), - _readers(readThreads, 128_Ki), + _readers(), _doneWriteWork(0), _doneReadWork(0), _stopRead(0), _reportWork(false) { + set_read_threads(1); } Fixture::~Fixture() { - _readers.sync(); - _readers.shutdown(); + if (_readers) { + _readers->sync(); + _readers->shutdown(); + } _writer.sync(); _writer.shutdown(); if (_reportWork) { @@ -75,6 +87,16 @@ Fixture::~Fixture() } } +void +Fixture::set_read_threads(uint32_t read_threads) +{ + if (_readers) { + _readers->sync(); + _readers->shutdown(); + } + _readThreads = read_threads; + _readers = std::make_unique<ThreadStackExecutor>(read_threads, 128_Ki); +} void Fixture::readWork(const WorkContext &context) @@ -85,7 +107,7 @@ Fixture::readWork(const WorkContext &context) for (i = 0; i < cnt && _stopRead.load() == 0; ++i) { auto guard = _generationHandler.takeGuard(); auto generation = context._generation.load(std::memory_order_relaxed); - EXPECT_GREATER_EQUAL(generation, guard.getGeneration()); + EXPECT_GE(generation, guard.getGeneration()); } _doneReadWork += i; LOG(info, "done %u read work", i); @@ -150,19 +172,32 @@ Fixture::stressTest(uint32_t writeCnt) auto context = std::make_shared<WorkContext>(); _writer.execute(std::make_unique<WriteWorkTask>(*this, writeCnt, context)); for (uint32_t i = 0; i < readThreads; ++i) { - _readers.execute(std::make_unique<ReadWorkTask>(*this, context)); + _readers->execute(std::make_unique<ReadWorkTask>(*this, context)); } + _writer.sync(); + _readers->sync(); } +using GenerationHandlerStressTest = Fixture; -TEST_F("stress test, 2 readers", Fixture(2)) +TEST_F(GenerationHandlerStressTest, stress_test_2_readers) { - f.stressTest(1000000); + set_read_threads(2); + stressTest(smoke_test ? 10000 : 1000000); } -TEST_F("stress test, 4 readers", Fixture(4)) +TEST_F(GenerationHandlerStressTest, stress_test_4_readers) { - f.stressTest(1000000); + set_read_threads(4); + stressTest(smoke_test ? 10000 : 1000000); } -TEST_MAIN() { TEST_RUN_ALL(); } +int main(int argc, char **argv) { + if (argc > 1 && argv[1] == smoke_test_option) { + smoke_test = true; + ++argv; + --argc; + } + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} |