Diffstat:
-rw-r--r--  client/go/vespa/target.go | 472
-rw-r--r--  client/go/vespa/target_cloud.go | 382
-rw-r--r--  client/go/vespa/target_custom.go | 128
-rw-r--r--  config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java | 1
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java | 2
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java | 6
-rw-r--r--  configdefinitions/src/vespa/dispatch.def | 3
-rw-r--r--  configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java | 3
-rw-r--r--  container-search/src/main/java/com/yahoo/search/dispatch/GroupingResultAggregator.java | 50
-rw-r--r--  container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java | 17
-rw-r--r--  container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java | 39
-rw-r--r--  container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java | 3
-rw-r--r--  flags/src/main/java/com/yahoo/vespa/flags/Flags.java | 7
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/aggregation/Grouping.java | 4
-rw-r--r--  vespalib/src/tests/util/generationhandler/CMakeLists.txt | 1
-rw-r--r--  vespalib/src/tests/util/generationhandler/generationhandler_test.cpp | 154
-rw-r--r--  vespalib/src/tests/util/generationhandler_stress/CMakeLists.txt | 3
-rw-r--r--  vespalib/src/tests/util/generationhandler_stress/generation_handler_stress_test.cpp | 77
18 files changed, 774 insertions(+), 578 deletions(-)
diff --git a/client/go/vespa/target.go b/client/go/vespa/target.go
index 2314cc06f33..8d48f1ffa3a 100644
--- a/client/go/vespa/target.go
+++ b/client/go/vespa/target.go
@@ -3,23 +3,15 @@
package vespa
import (
- "bytes"
"crypto/tls"
- "encoding/json"
"fmt"
"io"
"io/ioutil"
- "math"
"net/http"
- "net/url"
- "sort"
- "strconv"
"time"
- "github.com/vespa-engine/vespa/client/go/auth0"
"github.com/vespa-engine/vespa/client/go/util"
"github.com/vespa-engine/vespa/client/go/version"
- "github.com/vespa-engine/vespa/client/go/zts"
)
const (
@@ -94,26 +86,6 @@ type LogOptions struct {
Level int
}
-// CloudOptions configures URL and authentication for a cloud target.
-type APIOptions struct {
- System System
- TLSOptions TLSOptions
- APIKey []byte
- AuthConfigPath string
-}
-
-// CloudDeploymentOptions configures the deployment to manage through a cloud target.
-type CloudDeploymentOptions struct {
- Deployment Deployment
- TLSOptions TLSOptions
- ClusterURLs map[string]string // Endpoints keyed on cluster name
-}
-
-type customTarget struct {
- targetType string
- baseURL string
-}
-
// Do sends request to this service. Any required authentication happens automatically.
func (s *Service) Do(request *http.Request, timeout time.Duration) (*http.Response, error) {
if s.TLSOptions.KeyPair.Certificate != nil {
@@ -143,12 +115,7 @@ func (s *Service) Wait(timeout time.Duration) (int, error) {
default:
return 0, fmt.Errorf("invalid service: %s", s.Name)
}
- req, err := http.NewRequest("GET", url, nil)
- if err != nil {
- return 0, err
- }
- okFunc := func(status int, response []byte) (bool, error) { return status/100 == 2, nil }
- return wait(okFunc, func() *http.Request { return req }, &s.TLSOptions.KeyPair, timeout)
+ return waitForOK(url, &s.TLSOptions.KeyPair, timeout)
}
func (s *Service) Description() string {
@@ -163,442 +130,23 @@ func (s *Service) Description() string {
return fmt.Sprintf("No description of service %s", s.Name)
}
-func (t *customTarget) Type() string { return t.targetType }
-
-func (t *customTarget) Deployment() Deployment { return Deployment{} }
-
-func (t *customTarget) Service(name string, timeout time.Duration, sessionOrRunID int64, cluster string) (*Service, error) {
- if timeout > 0 && name != DeployService {
- if err := t.waitForConvergence(timeout); err != nil {
- return nil, err
- }
- }
- switch name {
- case DeployService, QueryService, DocumentService:
- url, err := t.urlWithPort(name)
- if err != nil {
- return nil, err
- }
- return &Service{BaseURL: url, Name: name}, nil
- }
- return nil, fmt.Errorf("unknown service: %s", name)
-}
-
-func (t *customTarget) PrintLog(options LogOptions) error {
- return fmt.Errorf("reading logs from non-cloud deployment is unsupported")
-}
-
-func (t *customTarget) SignRequest(req *http.Request, sigKeyId string) error { return nil }
+func isOK(status int) bool { return status/100 == 2 }
-func (t *customTarget) CheckVersion(version version.Version) error { return nil }
+type responseFunc func(status int, response []byte) (bool, error)
-func (t *customTarget) urlWithPort(serviceName string) (string, error) {
- u, err := url.Parse(t.baseURL)
- if err != nil {
- return "", err
- }
- port := u.Port()
- if port == "" {
- switch serviceName {
- case DeployService:
- port = "19071"
- case QueryService, DocumentService:
- port = "8080"
- default:
- return "", fmt.Errorf("unknown service: %s", serviceName)
- }
- u.Host = u.Host + ":" + port
- }
- return u.String(), nil
-}
+type requestFunc func() *http.Request
-func (t *customTarget) waitForConvergence(timeout time.Duration) error {
- deployer, err := t.Service(DeployService, 0, 0, "")
- if err != nil {
- return err
- }
- url := fmt.Sprintf("%s/application/v2/tenant/default/application/default/environment/prod/region/default/instance/default/serviceconverge", deployer.BaseURL)
+// waitForOK queries url and returns its status code. If the url returns a non-200 status code, it is repeatedly queried
+// until timeout elapses.
+func waitForOK(url string, certificate *tls.Certificate, timeout time.Duration) (int, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
- return err
- }
- converged := false
- convergedFunc := func(status int, response []byte) (bool, error) {
- if status/100 != 2 {
- return false, nil
- }
- var resp serviceConvergeResponse
- if err := json.Unmarshal(response, &resp); err != nil {
- return false, nil
- }
- converged = resp.Converged
- return converged, nil
- }
- if _, err := wait(convergedFunc, func() *http.Request { return req }, nil, timeout); err != nil {
- return err
- }
- if !converged {
- return fmt.Errorf("services have not converged")
- }
- return nil
-}
-
-type cloudTarget struct {
- apiOptions APIOptions
- deploymentOptions CloudDeploymentOptions
- logOptions LogOptions
- ztsClient ztsClient
-}
-
-type ztsClient interface {
- AccessToken(domain string, certficiate tls.Certificate) (string, error)
-}
-
-func (t *cloudTarget) resolveEndpoint(cluster string) (string, error) {
- if cluster == "" {
- for _, u := range t.deploymentOptions.ClusterURLs {
- if len(t.deploymentOptions.ClusterURLs) == 1 {
- return u, nil
- } else {
- return "", fmt.Errorf("multiple clusters, none chosen: %v", t.deploymentOptions.ClusterURLs)
- }
- }
- } else {
- u := t.deploymentOptions.ClusterURLs[cluster]
- if u == "" {
- clusters := make([]string, len(t.deploymentOptions.ClusterURLs))
- for c := range t.deploymentOptions.ClusterURLs {
- clusters = append(clusters, c)
- }
- return "", fmt.Errorf("unknown cluster '%s': must be one of %v", cluster, clusters)
- }
- return u, nil
- }
-
- return "", fmt.Errorf("no endpoints")
-}
-
-func (t *cloudTarget) Type() string {
- switch t.apiOptions.System.Name {
- case MainSystem.Name, CDSystem.Name:
- return TargetHosted
- }
- return TargetCloud
-}
-
-func (t *cloudTarget) Deployment() Deployment { return t.deploymentOptions.Deployment }
-
-func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, cluster string) (*Service, error) {
- if name != DeployService && t.deploymentOptions.ClusterURLs == nil {
- if err := t.waitForEndpoints(timeout, runID); err != nil {
- return nil, err
- }
- }
- switch name {
- case DeployService:
- return &Service{Name: name, BaseURL: t.apiOptions.System.URL, TLSOptions: t.apiOptions.TLSOptions, ztsClient: t.ztsClient}, nil
- case QueryService, DocumentService:
- url, err := t.resolveEndpoint(cluster)
- if err != nil {
- return nil, err
- }
- t.deploymentOptions.TLSOptions.AthenzDomain = t.apiOptions.System.AthenzDomain
- return &Service{Name: name, BaseURL: url, TLSOptions: t.deploymentOptions.TLSOptions, ztsClient: t.ztsClient}, nil
- }
- return nil, fmt.Errorf("unknown service: %s", name)
-}
-
-func (t *cloudTarget) SignRequest(req *http.Request, keyID string) error {
- if t.apiOptions.System.IsPublic() {
- if t.apiOptions.APIKey != nil {
- signer := NewRequestSigner(keyID, t.apiOptions.APIKey)
- return signer.SignRequest(req)
- } else {
- return t.addAuth0AccessToken(req)
- }
- } else {
- if t.apiOptions.TLSOptions.KeyPair.Certificate == nil {
- return fmt.Errorf("system %s requires a certificate for authentication", t.apiOptions.System.Name)
- }
- return nil
- }
-}
-
-func (t *cloudTarget) CheckVersion(clientVersion version.Version) error {
- if clientVersion.IsZero() { // development version is always fine
- return nil
- }
- req, err := http.NewRequest("GET", fmt.Sprintf("%s/cli/v1/", t.apiOptions.System.URL), nil)
- if err != nil {
- return err
- }
- response, err := util.HttpDo(req, 10*time.Second, "")
- if err != nil {
- return err
- }
- defer response.Body.Close()
- var cliResponse struct {
- MinVersion string `json:"minVersion"`
- }
- dec := json.NewDecoder(response.Body)
- if err := dec.Decode(&cliResponse); err != nil {
- return err
- }
- minVersion, err := version.Parse(cliResponse.MinVersion)
- if err != nil {
- return err
- }
- if clientVersion.Less(minVersion) {
- return fmt.Errorf("client version %s is less than the minimum supported version: %s", clientVersion, minVersion)
- }
- return nil
-}
-
-func (t *cloudTarget) addAuth0AccessToken(request *http.Request) error {
- a, err := auth0.GetAuth0(t.apiOptions.AuthConfigPath, t.apiOptions.System.Name, t.apiOptions.System.URL)
- if err != nil {
- return err
- }
- system, err := a.PrepareSystem(auth0.ContextWithCancel())
- if err != nil {
- return err
- }
- request.Header.Set("Authorization", "Bearer "+system.AccessToken)
- return nil
-}
-
-func (t *cloudTarget) logsURL() string {
- return fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s/logs",
- t.apiOptions.System.URL,
- t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance,
- t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region)
-}
-
-func (t *cloudTarget) PrintLog(options LogOptions) error {
- req, err := http.NewRequest("GET", t.logsURL(), nil)
- if err != nil {
- return err
- }
- lastFrom := options.From
- requestFunc := func() *http.Request {
- fromMillis := lastFrom.Unix() * 1000
- q := req.URL.Query()
- q.Set("from", strconv.FormatInt(fromMillis, 10))
- if !options.To.IsZero() {
- toMillis := options.To.Unix() * 1000
- q.Set("to", strconv.FormatInt(toMillis, 10))
- }
- req.URL.RawQuery = q.Encode()
- t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm())
- return req
- }
- logFunc := func(status int, response []byte) (bool, error) {
- if ok, err := isOK(status); !ok {
- return ok, err
- }
- logEntries, err := ReadLogEntries(bytes.NewReader(response))
- if err != nil {
- return true, err
- }
- for _, le := range logEntries {
- if !le.Time.After(lastFrom) {
- continue
- }
- if LogLevel(le.Level) > options.Level {
- continue
- }
- fmt.Fprintln(options.Writer, le.Format(options.Dequote))
- }
- if len(logEntries) > 0 {
- lastFrom = logEntries[len(logEntries)-1].Time
- }
- return false, nil
- }
- var timeout time.Duration
- if options.Follow {
- timeout = math.MaxInt64 // No timeout
- }
- _, err = wait(logFunc, requestFunc, &t.apiOptions.TLSOptions.KeyPair, timeout)
- return err
-}
-
-func (t *cloudTarget) waitForEndpoints(timeout time.Duration, runID int64) error {
- if runID > 0 {
- if err := t.waitForRun(runID, timeout); err != nil {
- return err
- }
- }
- return t.discoverEndpoints(timeout)
-}
-
-func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error {
- runURL := fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/job/%s-%s/run/%d",
- t.apiOptions.System.URL,
- t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance,
- t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region, runID)
- req, err := http.NewRequest("GET", runURL, nil)
- if err != nil {
- return err
- }
- lastID := int64(-1)
- requestFunc := func() *http.Request {
- q := req.URL.Query()
- q.Set("after", strconv.FormatInt(lastID, 10))
- req.URL.RawQuery = q.Encode()
- if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil {
- panic(err)
- }
- return req
- }
- jobSuccessFunc := func(status int, response []byte) (bool, error) {
- if ok, err := isOK(status); !ok {
- return ok, err
- }
- var resp jobResponse
- if err := json.Unmarshal(response, &resp); err != nil {
- return false, nil
- }
- if t.logOptions.Writer != nil {
- lastID = t.printLog(resp, lastID)
- }
- if resp.Active {
- return false, nil
- }
- if resp.Status != "success" {
- return false, fmt.Errorf("run %d ended with unsuccessful status: %s", runID, resp.Status)
- }
- return true, nil
- }
- _, err = wait(jobSuccessFunc, requestFunc, &t.apiOptions.TLSOptions.KeyPair, timeout)
- return err
-}
-
-func (t *cloudTarget) printLog(response jobResponse, last int64) int64 {
- if response.LastID == 0 {
- return last
- }
- var msgs []logMessage
- for step, stepMsgs := range response.Log {
- for _, msg := range stepMsgs {
- if step == "copyVespaLogs" && LogLevel(msg.Type) > t.logOptions.Level || LogLevel(msg.Type) == 3 {
- continue
- }
- msgs = append(msgs, msg)
- }
- }
- sort.Slice(msgs, func(i, j int) bool { return msgs[i].At < msgs[j].At })
- for _, msg := range msgs {
- tm := time.Unix(msg.At/1000, (msg.At%1000)*1000)
- fmtTime := tm.Format("15:04:05")
- fmt.Fprintf(t.logOptions.Writer, "[%s] %-7s %s\n", fmtTime, msg.Type, msg.Message)
- }
- return response.LastID
-}
-
-func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error {
- deploymentURL := fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s",
- t.apiOptions.System.URL,
- t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance,
- t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region)
- req, err := http.NewRequest("GET", deploymentURL, nil)
- if err != nil {
- return err
- }
- if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil {
- return err
- }
- urlsByCluster := make(map[string]string)
- endpointFunc := func(status int, response []byte) (bool, error) {
- if ok, err := isOK(status); !ok {
- return ok, err
- }
- var resp deploymentResponse
- if err := json.Unmarshal(response, &resp); err != nil {
- return false, nil
- }
- if len(resp.Endpoints) == 0 {
- return false, nil
- }
- for _, endpoint := range resp.Endpoints {
- if endpoint.Scope != "zone" {
- continue
- }
- urlsByCluster[endpoint.Cluster] = endpoint.URL
- }
- return true, nil
- }
- if _, err = wait(endpointFunc, func() *http.Request { return req }, &t.apiOptions.TLSOptions.KeyPair, timeout); err != nil {
- return err
- }
- if len(urlsByCluster) == 0 {
- return fmt.Errorf("no endpoints discovered")
- }
- t.deploymentOptions.ClusterURLs = urlsByCluster
- return nil
-}
-
-func isOK(status int) (bool, error) {
- if status == 401 {
- return false, fmt.Errorf("status %d: invalid api key", status)
- }
- return status/100 == 2, nil
-}
-
-// LocalTarget creates a target for a Vespa platform running locally.
-func LocalTarget() Target {
- return &customTarget{targetType: TargetLocal, baseURL: "http://127.0.0.1"}
-}
-
-// CustomTarget creates a Target for a Vespa platform running at baseURL.
-func CustomTarget(baseURL string) Target {
- return &customTarget{targetType: TargetCustom, baseURL: baseURL}
-}
-
-// CloudTarget creates a Target for the Vespa Cloud or hosted Vespa platform.
-func CloudTarget(apiOptions APIOptions, deploymentOptions CloudDeploymentOptions, logOptions LogOptions) (Target, error) {
- ztsClient, err := zts.NewClient(zts.DefaultURL, util.ActiveHttpClient)
- if err != nil {
- return nil, err
+ return 0, err
}
- return &cloudTarget{
- apiOptions: apiOptions,
- deploymentOptions: deploymentOptions,
- logOptions: logOptions,
- ztsClient: ztsClient,
- }, nil
-}
-
-type deploymentEndpoint struct {
- Cluster string `json:"cluster"`
- URL string `json:"url"`
- Scope string `json:"scope"`
-}
-
-type deploymentResponse struct {
- Endpoints []deploymentEndpoint `json:"endpoints"`
+ okFunc := func(status int, response []byte) (bool, error) { return isOK(status), nil }
+ return wait(okFunc, func() *http.Request { return req }, certificate, timeout)
}
-type serviceConvergeResponse struct {
- Converged bool `json:"converged"`
-}
-
-type jobResponse struct {
- Active bool `json:"active"`
- Status string `json:"status"`
- Log map[string][]logMessage `json:"log"`
- LastID int64 `json:"lastId"`
-}
-
-type logMessage struct {
- At int64 `json:"at"`
- Type string `json:"type"`
- Message string `json:"message"`
-}
-
-type responseFunc func(status int, response []byte) (bool, error)
-
-type requestFunc func() *http.Request
-
func wait(fn responseFunc, reqFn requestFunc, certificate *tls.Certificate, timeout time.Duration) (int, error) {
if certificate != nil {
util.ActiveHttpClient.UseCertificate([]tls.Certificate{*certificate})
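
The refactor above funnels all health checks through wait(), which replays a request until a responseFunc reports success or the timeout elapses; waitForOK is the common 2xx case. A minimal, self-contained sketch of that polling pattern, assuming only the behaviour visible in this hunk (pollUntil, the one-second retry interval and the example URL are illustrative names, not part of the Vespa client):

package main

import (
	"fmt"
	"net/http"
	"time"
)

// pollUntil issues GET requests against url until ok() reports success,
// ok() returns an error, or timeout elapses. It returns the last status seen.
func pollUntil(url string, ok func(status int) (bool, error), timeout time.Duration) (int, error) {
	deadline := time.Now().Add(timeout)
	lastStatus := 0
	for {
		if resp, err := http.Get(url); err == nil {
			lastStatus = resp.StatusCode
			resp.Body.Close()
			done, err := ok(lastStatus)
			if err != nil {
				return lastStatus, err // give up immediately, e.g. on auth errors
			}
			if done {
				return lastStatus, nil
			}
		}
		if time.Now().After(deadline) {
			return lastStatus, fmt.Errorf("timed out waiting for %s", url)
		}
		time.Sleep(time.Second)
	}
}

func main() {
	// Equivalent of waitForOK: accept any 2xx status. URL is an example only.
	status, err := pollUntil("http://127.0.0.1:19071/status.html",
		func(s int) (bool, error) { return s/100 == 2, nil }, 10*time.Second)
	fmt.Println(status, err)
}
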
diff --git a/client/go/vespa/target_cloud.go b/client/go/vespa/target_cloud.go
new file mode 100644
index 00000000000..f4eccaacab6
--- /dev/null
+++ b/client/go/vespa/target_cloud.go
@@ -0,0 +1,382 @@
+package vespa
+
+import (
+ "bytes"
+ "crypto/tls"
+ "encoding/json"
+ "fmt"
+ "math"
+ "net/http"
+ "sort"
+ "strconv"
+ "time"
+
+ "github.com/vespa-engine/vespa/client/go/auth0"
+ "github.com/vespa-engine/vespa/client/go/util"
+ "github.com/vespa-engine/vespa/client/go/version"
+ "github.com/vespa-engine/vespa/client/go/zts"
+)
+
+// CloudOptions configures URL and authentication for a cloud target.
+type APIOptions struct {
+ System System
+ TLSOptions TLSOptions
+ APIKey []byte
+ AuthConfigPath string
+}
+
+// CloudDeploymentOptions configures the deployment to manage through a cloud target.
+type CloudDeploymentOptions struct {
+ Deployment Deployment
+ TLSOptions TLSOptions
+ ClusterURLs map[string]string // Endpoints keyed on cluster name
+}
+
+type cloudTarget struct {
+ apiOptions APIOptions
+ deploymentOptions CloudDeploymentOptions
+ logOptions LogOptions
+ ztsClient ztsClient
+}
+
+type deploymentEndpoint struct {
+ Cluster string `json:"cluster"`
+ URL string `json:"url"`
+ Scope string `json:"scope"`
+}
+
+type deploymentResponse struct {
+ Endpoints []deploymentEndpoint `json:"endpoints"`
+}
+
+type jobResponse struct {
+ Active bool `json:"active"`
+ Status string `json:"status"`
+ Log map[string][]logMessage `json:"log"`
+ LastID int64 `json:"lastId"`
+}
+
+type logMessage struct {
+ At int64 `json:"at"`
+ Type string `json:"type"`
+ Message string `json:"message"`
+}
+
+type ztsClient interface {
+ AccessToken(domain string, certficiate tls.Certificate) (string, error)
+}
+
+// CloudTarget creates a Target for the Vespa Cloud or hosted Vespa platform.
+func CloudTarget(apiOptions APIOptions, deploymentOptions CloudDeploymentOptions, logOptions LogOptions) (Target, error) {
+ ztsClient, err := zts.NewClient(zts.DefaultURL, util.ActiveHttpClient)
+ if err != nil {
+ return nil, err
+ }
+ return &cloudTarget{
+ apiOptions: apiOptions,
+ deploymentOptions: deploymentOptions,
+ logOptions: logOptions,
+ ztsClient: ztsClient,
+ }, nil
+}
+
+func (t *cloudTarget) resolveEndpoint(cluster string) (string, error) {
+ if cluster == "" {
+ for _, u := range t.deploymentOptions.ClusterURLs {
+ if len(t.deploymentOptions.ClusterURLs) == 1 {
+ return u, nil
+ } else {
+ return "", fmt.Errorf("multiple clusters, none chosen: %v", t.deploymentOptions.ClusterURLs)
+ }
+ }
+ } else {
+ u := t.deploymentOptions.ClusterURLs[cluster]
+ if u == "" {
+ clusters := make([]string, len(t.deploymentOptions.ClusterURLs))
+ for c := range t.deploymentOptions.ClusterURLs {
+ clusters = append(clusters, c)
+ }
+ return "", fmt.Errorf("unknown cluster '%s': must be one of %v", cluster, clusters)
+ }
+ return u, nil
+ }
+
+ return "", fmt.Errorf("no endpoints")
+}
+
+func (t *cloudTarget) Type() string {
+ switch t.apiOptions.System.Name {
+ case MainSystem.Name, CDSystem.Name:
+ return TargetHosted
+ }
+ return TargetCloud
+}
+
+func (t *cloudTarget) Deployment() Deployment { return t.deploymentOptions.Deployment }
+
+func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, cluster string) (*Service, error) {
+ switch name {
+ case DeployService:
+ service := &Service{Name: name, BaseURL: t.apiOptions.System.URL, TLSOptions: t.apiOptions.TLSOptions, ztsClient: t.ztsClient}
+ if timeout > 0 {
+ status, err := service.Wait(timeout)
+ if err != nil {
+ return nil, err
+ }
+ if !isOK(status) {
+ return nil, fmt.Errorf("got status %d from deploy service at %s", status, service.BaseURL)
+ }
+ }
+ return service, nil
+ case QueryService, DocumentService:
+ if t.deploymentOptions.ClusterURLs == nil {
+ if err := t.waitForEndpoints(timeout, runID); err != nil {
+ return nil, err
+ }
+ }
+ url, err := t.resolveEndpoint(cluster)
+ if err != nil {
+ return nil, err
+ }
+ t.deploymentOptions.TLSOptions.AthenzDomain = t.apiOptions.System.AthenzDomain
+ return &Service{Name: name, BaseURL: url, TLSOptions: t.deploymentOptions.TLSOptions, ztsClient: t.ztsClient}, nil
+ }
+ return nil, fmt.Errorf("unknown service: %s", name)
+}
+
+func (t *cloudTarget) SignRequest(req *http.Request, keyID string) error {
+ if t.apiOptions.System.IsPublic() {
+ if t.apiOptions.APIKey != nil {
+ signer := NewRequestSigner(keyID, t.apiOptions.APIKey)
+ return signer.SignRequest(req)
+ } else {
+ return t.addAuth0AccessToken(req)
+ }
+ } else {
+ if t.apiOptions.TLSOptions.KeyPair.Certificate == nil {
+ return fmt.Errorf("system %s requires a certificate for authentication", t.apiOptions.System.Name)
+ }
+ return nil
+ }
+}
+
+func (t *cloudTarget) CheckVersion(clientVersion version.Version) error {
+ if clientVersion.IsZero() { // development version is always fine
+ return nil
+ }
+ req, err := http.NewRequest("GET", fmt.Sprintf("%s/cli/v1/", t.apiOptions.System.URL), nil)
+ if err != nil {
+ return err
+ }
+ response, err := util.HttpDo(req, 10*time.Second, "")
+ if err != nil {
+ return err
+ }
+ defer response.Body.Close()
+ var cliResponse struct {
+ MinVersion string `json:"minVersion"`
+ }
+ dec := json.NewDecoder(response.Body)
+ if err := dec.Decode(&cliResponse); err != nil {
+ return err
+ }
+ minVersion, err := version.Parse(cliResponse.MinVersion)
+ if err != nil {
+ return err
+ }
+ if clientVersion.Less(minVersion) {
+ return fmt.Errorf("client version %s is less than the minimum supported version: %s", clientVersion, minVersion)
+ }
+ return nil
+}
+
+func (t *cloudTarget) addAuth0AccessToken(request *http.Request) error {
+ a, err := auth0.GetAuth0(t.apiOptions.AuthConfigPath, t.apiOptions.System.Name, t.apiOptions.System.URL)
+ if err != nil {
+ return err
+ }
+ system, err := a.PrepareSystem(auth0.ContextWithCancel())
+ if err != nil {
+ return err
+ }
+ request.Header.Set("Authorization", "Bearer "+system.AccessToken)
+ return nil
+}
+
+func (t *cloudTarget) logsURL() string {
+ return fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s/logs",
+ t.apiOptions.System.URL,
+ t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance,
+ t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region)
+}
+
+func (t *cloudTarget) PrintLog(options LogOptions) error {
+ req, err := http.NewRequest("GET", t.logsURL(), nil)
+ if err != nil {
+ return err
+ }
+ lastFrom := options.From
+ requestFunc := func() *http.Request {
+ fromMillis := lastFrom.Unix() * 1000
+ q := req.URL.Query()
+ q.Set("from", strconv.FormatInt(fromMillis, 10))
+ if !options.To.IsZero() {
+ toMillis := options.To.Unix() * 1000
+ q.Set("to", strconv.FormatInt(toMillis, 10))
+ }
+ req.URL.RawQuery = q.Encode()
+ t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm())
+ return req
+ }
+ logFunc := func(status int, response []byte) (bool, error) {
+ if ok, err := isCloudOK(status); !ok {
+ return ok, err
+ }
+ logEntries, err := ReadLogEntries(bytes.NewReader(response))
+ if err != nil {
+ return true, err
+ }
+ for _, le := range logEntries {
+ if !le.Time.After(lastFrom) {
+ continue
+ }
+ if LogLevel(le.Level) > options.Level {
+ continue
+ }
+ fmt.Fprintln(options.Writer, le.Format(options.Dequote))
+ }
+ if len(logEntries) > 0 {
+ lastFrom = logEntries[len(logEntries)-1].Time
+ }
+ return false, nil
+ }
+ var timeout time.Duration
+ if options.Follow {
+ timeout = math.MaxInt64 // No timeout
+ }
+ _, err = wait(logFunc, requestFunc, &t.apiOptions.TLSOptions.KeyPair, timeout)
+ return err
+}
+
+func (t *cloudTarget) waitForEndpoints(timeout time.Duration, runID int64) error {
+ if runID > 0 {
+ if err := t.waitForRun(runID, timeout); err != nil {
+ return err
+ }
+ }
+ return t.discoverEndpoints(timeout)
+}
+
+func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error {
+ runURL := fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/job/%s-%s/run/%d",
+ t.apiOptions.System.URL,
+ t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance,
+ t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region, runID)
+ req, err := http.NewRequest("GET", runURL, nil)
+ if err != nil {
+ return err
+ }
+ lastID := int64(-1)
+ requestFunc := func() *http.Request {
+ q := req.URL.Query()
+ q.Set("after", strconv.FormatInt(lastID, 10))
+ req.URL.RawQuery = q.Encode()
+ if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil {
+ panic(err)
+ }
+ return req
+ }
+ jobSuccessFunc := func(status int, response []byte) (bool, error) {
+ if ok, err := isCloudOK(status); !ok {
+ return ok, err
+ }
+ var resp jobResponse
+ if err := json.Unmarshal(response, &resp); err != nil {
+ return false, nil
+ }
+ if t.logOptions.Writer != nil {
+ lastID = t.printLog(resp, lastID)
+ }
+ if resp.Active {
+ return false, nil
+ }
+ if resp.Status != "success" {
+ return false, fmt.Errorf("run %d ended with unsuccessful status: %s", runID, resp.Status)
+ }
+ return true, nil
+ }
+ _, err = wait(jobSuccessFunc, requestFunc, &t.apiOptions.TLSOptions.KeyPair, timeout)
+ return err
+}
+
+func (t *cloudTarget) printLog(response jobResponse, last int64) int64 {
+ if response.LastID == 0 {
+ return last
+ }
+ var msgs []logMessage
+ for step, stepMsgs := range response.Log {
+ for _, msg := range stepMsgs {
+ if step == "copyVespaLogs" && LogLevel(msg.Type) > t.logOptions.Level || LogLevel(msg.Type) == 3 {
+ continue
+ }
+ msgs = append(msgs, msg)
+ }
+ }
+ sort.Slice(msgs, func(i, j int) bool { return msgs[i].At < msgs[j].At })
+ for _, msg := range msgs {
+ tm := time.Unix(msg.At/1000, (msg.At%1000)*1000)
+ fmtTime := tm.Format("15:04:05")
+ fmt.Fprintf(t.logOptions.Writer, "[%s] %-7s %s\n", fmtTime, msg.Type, msg.Message)
+ }
+ return response.LastID
+}
+
+func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error {
+ deploymentURL := fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s",
+ t.apiOptions.System.URL,
+ t.deploymentOptions.Deployment.Application.Tenant, t.deploymentOptions.Deployment.Application.Application, t.deploymentOptions.Deployment.Application.Instance,
+ t.deploymentOptions.Deployment.Zone.Environment, t.deploymentOptions.Deployment.Zone.Region)
+ req, err := http.NewRequest("GET", deploymentURL, nil)
+ if err != nil {
+ return err
+ }
+ if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil {
+ return err
+ }
+ urlsByCluster := make(map[string]string)
+ endpointFunc := func(status int, response []byte) (bool, error) {
+ if ok, err := isCloudOK(status); !ok {
+ return ok, err
+ }
+ var resp deploymentResponse
+ if err := json.Unmarshal(response, &resp); err != nil {
+ return false, nil
+ }
+ if len(resp.Endpoints) == 0 {
+ return false, nil
+ }
+ for _, endpoint := range resp.Endpoints {
+ if endpoint.Scope != "zone" {
+ continue
+ }
+ urlsByCluster[endpoint.Cluster] = endpoint.URL
+ }
+ return true, nil
+ }
+ if _, err = wait(endpointFunc, func() *http.Request { return req }, &t.apiOptions.TLSOptions.KeyPair, timeout); err != nil {
+ return err
+ }
+ if len(urlsByCluster) == 0 {
+ return fmt.Errorf("no endpoints discovered")
+ }
+ t.deploymentOptions.ClusterURLs = urlsByCluster
+ return nil
+}
+
+func isCloudOK(status int) (bool, error) {
+ if status == 401 {
+ // when retrying we should give up immediately if we're not authorized
+ return false, fmt.Errorf("status %d: invalid credentials", status)
+ }
+ return isOK(status), nil
+}
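
resolveEndpoint above only accepts an empty cluster name when the deployment exposes exactly one cluster; otherwise the caller must name one of the known clusters. A standalone sketch of that rule, under the assumption that only the map of cluster URLs matters (resolve and the example URLs are illustrative, not the real cloudTarget method):

package main

import "fmt"

// resolve picks an endpoint URL from urls (keyed on cluster name). An empty
// cluster name is only allowed when there is exactly one cluster to choose from.
func resolve(urls map[string]string, cluster string) (string, error) {
	if len(urls) == 0 {
		return "", fmt.Errorf("no endpoints")
	}
	if cluster == "" {
		if len(urls) > 1 {
			return "", fmt.Errorf("multiple clusters, none chosen: %v", urls)
		}
		for _, u := range urls { // exactly one entry
			return u, nil
		}
	}
	if u, ok := urls[cluster]; ok {
		return u, nil
	}
	return "", fmt.Errorf("unknown cluster %q", cluster)
}

func main() {
	urls := map[string]string{
		"container": "https://container.example.com",
		"feed":      "https://feed.example.com",
	}
	fmt.Println(resolve(urls, "feed")) // https://feed.example.com <nil>
	fmt.Println(resolve(urls, ""))     // error: multiple clusters, none chosen
}
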
diff --git a/client/go/vespa/target_custom.go b/client/go/vespa/target_custom.go
new file mode 100644
index 00000000000..072ec8649e4
--- /dev/null
+++ b/client/go/vespa/target_custom.go
@@ -0,0 +1,128 @@
+package vespa
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+ "time"
+
+ "github.com/vespa-engine/vespa/client/go/version"
+)
+
+type customTarget struct {
+ targetType string
+ baseURL string
+}
+
+type serviceConvergeResponse struct {
+ Converged bool `json:"converged"`
+}
+
+// LocalTarget creates a target for a Vespa platform running locally.
+func LocalTarget() Target {
+ return &customTarget{targetType: TargetLocal, baseURL: "http://127.0.0.1"}
+}
+
+// CustomTarget creates a Target for a Vespa platform running at baseURL.
+func CustomTarget(baseURL string) Target {
+ return &customTarget{targetType: TargetCustom, baseURL: baseURL}
+}
+
+func (t *customTarget) Type() string { return t.targetType }
+
+func (t *customTarget) Deployment() Deployment { return Deployment{} }
+
+func (t *customTarget) createService(name string) (*Service, error) {
+ switch name {
+ case DeployService, QueryService, DocumentService:
+ url, err := t.urlWithPort(name)
+ if err != nil {
+ return nil, err
+ }
+ return &Service{BaseURL: url, Name: name}, nil
+ }
+ return nil, fmt.Errorf("unknown service: %s", name)
+}
+
+func (t *customTarget) Service(name string, timeout time.Duration, sessionOrRunID int64, cluster string) (*Service, error) {
+ service, err := t.createService(name)
+ if err != nil {
+ return nil, err
+ }
+ if timeout > 0 {
+ if name == DeployService {
+ status, err := service.Wait(timeout)
+ if err != nil {
+ return nil, err
+ }
+ if !isOK(status) {
+ return nil, fmt.Errorf("got status %d from deploy service at %s", status, service.BaseURL)
+ }
+ } else {
+ if err := t.waitForConvergence(timeout); err != nil {
+ return nil, err
+ }
+ }
+ }
+ return service, nil
+}
+
+func (t *customTarget) PrintLog(options LogOptions) error {
+ return fmt.Errorf("reading logs from non-cloud deployment is unsupported")
+}
+
+func (t *customTarget) SignRequest(req *http.Request, sigKeyId string) error { return nil }
+
+func (t *customTarget) CheckVersion(version version.Version) error { return nil }
+
+func (t *customTarget) urlWithPort(serviceName string) (string, error) {
+ u, err := url.Parse(t.baseURL)
+ if err != nil {
+ return "", err
+ }
+ port := u.Port()
+ if port == "" {
+ switch serviceName {
+ case DeployService:
+ port = "19071"
+ case QueryService, DocumentService:
+ port = "8080"
+ default:
+ return "", fmt.Errorf("unknown service: %s", serviceName)
+ }
+ u.Host = u.Host + ":" + port
+ }
+ return u.String(), nil
+}
+
+func (t *customTarget) waitForConvergence(timeout time.Duration) error {
+ deployURL, err := t.urlWithPort(DeployService)
+ if err != nil {
+ return err
+ }
+ url := fmt.Sprintf("%s/application/v2/tenant/default/application/default/environment/prod/region/default/instance/default/serviceconverge", deployURL)
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ return err
+ }
+ converged := false
+ convergedFunc := func(status int, response []byte) (bool, error) {
+ if !isOK(status) {
+ return false, nil
+ }
+ var resp serviceConvergeResponse
+ if err := json.Unmarshal(response, &resp); err != nil {
+ return false, nil
+ }
+ converged = resp.Converged
+ return converged, nil
+ }
+ if _, err := wait(convergedFunc, func() *http.Request { return req }, nil, timeout); err != nil {
+ return err
+ }
+ if !converged {
+ return fmt.Errorf("services have not converged")
+ }
+ return nil
+}
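
urlWithPort above appends a default port (19071 for the deploy service, 8080 for query/document) only when the base URL does not already carry one. A small standalone sketch of that behaviour (withDefaultPort is an illustrative name, not part of the client):

package main

import (
	"fmt"
	"net/url"
)

// withDefaultPort appends defaultPort to baseURL unless a port is already present.
func withDefaultPort(baseURL, defaultPort string) (string, error) {
	u, err := url.Parse(baseURL)
	if err != nil {
		return "", err
	}
	if u.Port() == "" {
		u.Host = u.Host + ":" + defaultPort
	}
	return u.String(), nil
}

func main() {
	fmt.Println(withDefaultPort("http://127.0.0.1", "19071"))     // http://127.0.0.1:19071
	fmt.Println(withDefaultPort("http://127.0.0.1:8443", "8080")) // existing port kept
}
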
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java
index e0e4318ccc8..7cb36374568 100644
--- a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java
+++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java
@@ -116,6 +116,7 @@ public interface ModelContext {
@ModelFeatureFlag(owners = {"arnej"}) default boolean useQrserverServiceName() { return true; }
@ModelFeatureFlag(owners = {"bjorncs", "baldersheim"}) default boolean enableJdiscPreshutdownCommand() { return true; }
@ModelFeatureFlag(owners = {"arnej"}) default boolean avoidRenamingSummaryFeatures() { return false; }
+ @ModelFeatureFlag(owners = {"bjorncs", "baldersheim"}) default boolean mergeGroupingResultInSearchInvoker() { return false; }
}
/** Warning: As elsewhere in this package, do not make backwards incompatible changes that will break old config models! */
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java
index 22b752777e9..e483351a25a 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java
@@ -158,7 +158,7 @@ public class ContentSearchCluster extends AbstractConfigProducer<SearchCluster>
String clusterName, ContentSearchCluster search) {
List<ModelElement> indexedDefs = getIndexedSchemas(clusterElem);
if (!indexedDefs.isEmpty()) {
- IndexedSearchCluster isc = new IndexedSearchCluster(search, clusterName, 0);
+ IndexedSearchCluster isc = new IndexedSearchCluster(deployState, search, clusterName, 0);
isc.setRoutingSelector(clusterElem.childAsString("documents.selection"));
Double visibilityDelay = clusterElem.childAsDouble("engine.proton.visibility-delay");
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java
index fb7c6696b54..53aac23135a 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java
@@ -6,7 +6,6 @@ import com.yahoo.config.model.producer.AbstractConfigProducer;
import com.yahoo.prelude.fastsearch.DocumentdbInfoConfig;
import com.yahoo.search.config.IndexInfoConfig;
import com.yahoo.searchdefinition.DocumentOnlySchema;
-import com.yahoo.searchdefinition.Schema;
import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.vespa.config.search.AttributesConfig;
import com.yahoo.vespa.config.search.DispatchConfig;
@@ -34,6 +33,7 @@ public class IndexedSearchCluster extends SearchCluster
IlscriptsConfig.Producer,
DispatchConfig.Producer {
+ private final boolean mergeGroupingResultInSearchInvoker;
private String indexingClusterName = null; // The name of the docproc cluster to run indexing, by config.
private String indexingChainName = null;
@@ -63,10 +63,11 @@ public class IndexedSearchCluster extends SearchCluster
return routingSelector;
}
- public IndexedSearchCluster(AbstractConfigProducer<SearchCluster> parent, String clusterName, int index) {
+ public IndexedSearchCluster(DeployState deployState, AbstractConfigProducer<SearchCluster> parent, String clusterName, int index) {
super(parent, clusterName, index);
unionCfg = new UnionConfiguration(this, documentDbs);
rootDispatch = new DispatchGroup(this);
+ mergeGroupingResultInSearchInvoker = deployState.featureFlags().mergeGroupingResultInSearchInvoker();
}
@Override
@@ -320,6 +321,7 @@ public class IndexedSearchCluster extends SearchCluster
builder.maxWaitAfterCoverageFactor(searchCoverage.getMaxWaitAfterCoverageFactor());
}
builder.warmuptime(5.0);
+ builder.mergeGroupingResultInSearchInvokerEnabled(mergeGroupingResultInSearchInvoker);
}
@Override
diff --git a/configdefinitions/src/vespa/dispatch.def b/configdefinitions/src/vespa/dispatch.def
index 17f42a73bfd..fef9300a410 100644
--- a/configdefinitions/src/vespa/dispatch.def
+++ b/configdefinitions/src/vespa/dispatch.def
@@ -71,3 +71,6 @@ node[].host string
# The rpc port of this search node
node[].port int
+
+# Temporary feature flag
+mergeGroupingResultInSearchInvokerEnabled bool default=false
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
index 5a813c7886a..6222e7b1788 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
@@ -209,6 +209,7 @@ public class ModelContextImpl implements ModelContext {
private final boolean inhibitDefaultMergesWhenGlobalMergesPending;
private final boolean useQrserverServiceName;
private final boolean avoidRenamingSummaryFeatures;
+ private final boolean mergeGroupingResultInSearchInvoker;
public FeatureFlags(FlagSource source, ApplicationId appId) {
this.defaultTermwiseLimit = flagValue(source, appId, Flags.DEFAULT_TERM_WISE_LIMIT);
@@ -256,6 +257,7 @@ public class ModelContextImpl implements ModelContext {
this.inhibitDefaultMergesWhenGlobalMergesPending = flagValue(source, appId, Flags.INHIBIT_DEFAULT_MERGES_WHEN_GLOBAL_MERGES_PENDING);
this.useQrserverServiceName = flagValue(source, appId, Flags.USE_QRSERVER_SERVICE_NAME);
this.avoidRenamingSummaryFeatures = flagValue(source, appId, Flags.AVOID_RENAMING_SUMMARY_FEATURES);
+ this.mergeGroupingResultInSearchInvoker = flagValue(source, appId, Flags.MERGE_GROUPING_RESULT_IN_SEARCH_INVOKER);
}
@Override public double defaultTermwiseLimit() { return defaultTermwiseLimit; }
@@ -305,6 +307,7 @@ public class ModelContextImpl implements ModelContext {
@Override public boolean inhibitDefaultMergesWhenGlobalMergesPending() { return inhibitDefaultMergesWhenGlobalMergesPending; }
@Override public boolean useQrserverServiceName() { return useQrserverServiceName; }
@Override public boolean avoidRenamingSummaryFeatures() { return avoidRenamingSummaryFeatures; }
+ @Override public boolean mergeGroupingResultInSearchInvoker() { return mergeGroupingResultInSearchInvoker; }
private static <V> V flagValue(FlagSource source, ApplicationId appId, UnboundFlag<? extends V, ?, ?> flag) {
return flag.bindTo(source)
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/GroupingResultAggregator.java b/container-search/src/main/java/com/yahoo/search/dispatch/GroupingResultAggregator.java
new file mode 100644
index 00000000000..5ce7accfdd4
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/GroupingResultAggregator.java
@@ -0,0 +1,50 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.dispatch;
+
+import com.yahoo.prelude.fastsearch.DocsumDefinitionSet;
+import com.yahoo.prelude.fastsearch.GroupingListHit;
+import com.yahoo.searchlib.aggregation.Grouping;
+
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * Incrementally merges underlying {@link Grouping} instances from {@link GroupingListHit} hits.
+ *
+ * @author bjorncs
+ */
+class GroupingResultAggregator {
+ private static final Logger log = Logger.getLogger(GroupingResultAggregator.class.getName());
+
+ private final Map<Integer, Grouping> groupings = new LinkedHashMap<>();
+ private DocsumDefinitionSet documentDefinitions = null;
+ private int groupingHitsMerged = 0;
+
+ void mergeWith(GroupingListHit result) {
+ if (groupingHitsMerged == 0) documentDefinitions = result.getDocsumDefinitionSet();
+ ++groupingHitsMerged;
+ log.log(Level.FINE, () ->
+ String.format("Merging hit #%d having %d groupings",
+ groupingHitsMerged, result.getGroupingList().size()));
+ for (Grouping grouping : result.getGroupingList()) {
+ groupings.merge(grouping.getId(), grouping, (existingGrouping, newGrouping) -> {
+ existingGrouping.merge(newGrouping);
+ return existingGrouping;
+ });
+ }
+ }
+
+ Optional<GroupingListHit> toAggregatedHit() {
+ if (groupingHitsMerged == 0) return Optional.empty();
+ log.log(Level.FINE, () ->
+ String.format("Creating aggregated hit containing %d groupings from %d hits",
+ groupings.size(), groupingHitsMerged));
+ groupings.values().forEach(Grouping::postMerge);
+ return Optional.of(new GroupingListHit(List.copyOf(groupings.values()), documentDefinitions));
+ }
+
+}
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java
index 4e658122cdf..d7c9f1dce53 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.dispatch;
+import com.yahoo.prelude.fastsearch.GroupingListHit;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.dispatch.searchcluster.Group;
@@ -44,6 +45,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM
private final Group group;
private final LinkedBlockingQueue<SearchInvoker> availableForProcessing;
private final Set<Integer> alreadyFailedNodes;
+ private final boolean mergeGroupingResult;
private Query query;
private boolean adaptiveTimeoutCalculated = false;
@@ -71,6 +73,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM
this.group = group;
this.availableForProcessing = newQueue();
this.alreadyFailedNodes = alreadyFailedNodes;
+ this.mergeGroupingResult = searchCluster.dispatchConfig().mergeGroupingResultInSearchInvokerEnabled();
}
/**
@@ -115,6 +118,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM
long nextTimeout = query.getTimeLeft();
boolean extraDebug = (query.getOffset() == 0) && (query.getHits() == 7) && log.isLoggable(java.util.logging.Level.FINE);
List<InvokerResult> processed = new ArrayList<>();
+ var groupingResultAggregator = new GroupingResultAggregator();
try {
while (!invokers.isEmpty() && nextTimeout >= 0) {
SearchInvoker invoker = availableForProcessing.poll(nextTimeout, TimeUnit.MILLISECONDS);
@@ -126,7 +130,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM
if (extraDebug) {
processed.add(toMerge);
}
- merged = mergeResult(result.getResult(), toMerge, merged);
+ merged = mergeResult(result.getResult(), toMerge, merged, groupingResultAggregator);
ejectInvoker(invoker);
}
nextTimeout = nextTimeout();
@@ -134,6 +138,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM
} catch (InterruptedException e) {
throw new RuntimeException("Interrupted while waiting for search results", e);
}
+ groupingResultAggregator.toAggregatedHit().ifPresent(h -> result.getResult().hits().add(h));
insertNetworkErrors(result.getResult());
result.getResult().setCoverage(createCoverage());
@@ -238,14 +243,20 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM
return nextAdaptive;
}
- private List<LeanHit> mergeResult(Result result, InvokerResult partialResult, List<LeanHit> current) {
+ private List<LeanHit> mergeResult(Result result, InvokerResult partialResult, List<LeanHit> current,
+ GroupingResultAggregator groupingResultAggregator) {
collectCoverage(partialResult.getResult().getCoverage(true));
result.mergeWith(partialResult.getResult());
List<Hit> partialNonLean = partialResult.getResult().hits().asUnorderedHits();
for(Hit hit : partialNonLean) {
if (hit.isAuxiliary()) {
- result.hits().add(hit);
+ if (hit instanceof GroupingListHit && mergeGroupingResult) {
+ var groupingHit = (GroupingListHit) hit;
+ groupingResultAggregator.mergeWith(groupingHit);
+ } else {
+ result.hits().add(hit);
+ }
}
}
if (current.isEmpty() ) {
diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java
index 6e08b1c6fa5..347276d680d 100644
--- a/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java
+++ b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java
@@ -4,6 +4,7 @@ package com.yahoo.search.dispatch;
import com.yahoo.document.GlobalId;
import com.yahoo.document.idstring.IdString;
import com.yahoo.prelude.fastsearch.FastHit;
+import com.yahoo.prelude.fastsearch.GroupingListHit;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.dispatch.searchcluster.Group;
@@ -13,6 +14,11 @@ import com.yahoo.search.result.Coverage;
import com.yahoo.search.result.ErrorMessage;
import com.yahoo.search.result.Hit;
import com.yahoo.search.result.Relevance;
+import com.yahoo.searchlib.aggregation.Grouping;
+import com.yahoo.searchlib.aggregation.MaxAggregationResult;
+import com.yahoo.searchlib.aggregation.MinAggregationResult;
+import com.yahoo.searchlib.expression.IntegerResultNode;
+import com.yahoo.searchlib.expression.StringResultNode;
import com.yahoo.test.ManualClock;
import org.junit.Test;
@@ -320,6 +326,39 @@ public class InterleavedSearchInvokerTest {
assertEquals(3, result.getQuery().getHits());
}
+ @Test
+ public void requireThatGroupingsAreMerged() throws IOException {
+ SearchCluster cluster = new MockSearchCluster("!", 1, 2);
+ List<SearchInvoker> invokers = new ArrayList<>();
+
+ Grouping grouping1 = new Grouping(0);
+ grouping1.setRoot(new com.yahoo.searchlib.aggregation.Group()
+ .addChild(new com.yahoo.searchlib.aggregation.Group()
+ .setId(new StringResultNode("uniqueA"))
+ .addAggregationResult(new MaxAggregationResult().setMax(new IntegerResultNode(6)).setTag(4)))
+ .addChild(new com.yahoo.searchlib.aggregation.Group()
+ .setId(new StringResultNode("common"))
+ .addAggregationResult(new MaxAggregationResult().setMax(new IntegerResultNode(9)).setTag(4))));
+ invokers.add(new MockInvoker(0).setHits(List.of(new GroupingListHit(List.of(grouping1)))));
+
+ Grouping grouping2 = new Grouping(0);
+ grouping2.setRoot(new com.yahoo.searchlib.aggregation.Group()
+ .addChild(new com.yahoo.searchlib.aggregation.Group()
+ .setId(new StringResultNode("uniqueB"))
+ .addAggregationResult(new MaxAggregationResult().setMax(new IntegerResultNode(9)).setTag(4)))
+ .addChild(new com.yahoo.searchlib.aggregation.Group()
+ .setId(new StringResultNode("common"))
+ .addAggregationResult(new MinAggregationResult().setMin(new IntegerResultNode(6)).setTag(3))));
+ invokers.add(new MockInvoker(0).setHits(List.of(new GroupingListHit(List.of(grouping2)))));
+
+ InterleavedSearchInvoker invoker = new InterleavedSearchInvoker(invokers, cluster, new Group(0, List.of()), Collections.emptySet());
+ invoker.responseAvailable(invokers.get(0));
+ invoker.responseAvailable(invokers.get(1));
+ Result result = invoker.search(query, null);
+ assertEquals(1, ((GroupingListHit) result.hits().get(0)).getGroupingList().size());
+
+ }
+
private static InterleavedSearchInvoker createInterLeavedTestInvoker(List<Double> a, List<Double> b, Group group) {
SearchCluster cluster = new MockSearchCluster("!", 1, 2);
List<SearchInvoker> invokers = new ArrayList<>();
diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java b/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java
index 2acce0f8d2d..54c8c1e0522 100644
--- a/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java
+++ b/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java
@@ -120,7 +120,8 @@ public class MockSearchCluster extends SearchCluster {
builder.minActivedocsPercentage(88.0);
builder.minGroupCoverage(99.0);
builder.minSearchCoverage(minSearchCoverage);
- builder.distributionPolicy(DispatchConfig.DistributionPolicy.Enum.ROUNDROBIN);
+ builder.distributionPolicy(DispatchConfig.DistributionPolicy.Enum.ROUNDROBIN)
+ .mergeGroupingResultInSearchInvokerEnabled(true);
if (minSearchCoverage < 100.0) {
builder.minWaitAfterCoverageFactor(0);
builder.maxWaitAfterCoverageFactor(0.5);
diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
index 881b62d1e04..5376fa983af 100644
--- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
+++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
@@ -390,6 +390,13 @@ public class Flags {
"Takes effect immediately",
TENANT_ID);
+ public static final UnboundBooleanFlag MERGE_GROUPING_RESULT_IN_SEARCH_INVOKER = defineFeatureFlag(
+ "merge-grouping-result-in-search-invoker", false,
+ List.of("bjorncs", "baldersheim"), "2022-02-23", "2022-08-01",
+ "Merge grouping results incrementally in interleaved search invoker",
+ "Takes effect at redeployment",
+ APPLICATION_ID);
+
/** WARNING: public for testing: All flags should be defined in {@link Flags}. */
public static UnboundBooleanFlag defineFeatureFlag(String flagId, boolean defaultValue, List<String> owners,
String createdAt, String expiresAt, String description,
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/aggregation/Grouping.java b/searchlib/src/main/java/com/yahoo/searchlib/aggregation/Grouping.java
index 25b3cb18ff9..c88a567c559 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/aggregation/Grouping.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/aggregation/Grouping.java
@@ -46,6 +46,8 @@ public class Grouping extends Identifiable {
// Actual root group, does not require level details.
private Group root = new Group();
+ private boolean postMergeCompleted = false;
+
/**
* <p>Constructs an empty result node. <b>NOTE:</b> This instance is broken until non-optional member data is
* set.</p>
@@ -78,7 +80,9 @@ public class Grouping extends Identifiable {
* that might have changes due to the merge.</p>
*/
public void postMerge() {
+ if (postMergeCompleted) return;
root.postMerge(groupingLevels, firstLevel, 0);
+ postMergeCompleted = true;
}
/**
diff --git a/vespalib/src/tests/util/generationhandler/CMakeLists.txt b/vespalib/src/tests/util/generationhandler/CMakeLists.txt
index 677d5caa0e6..fdf54c59854 100644
--- a/vespalib/src/tests/util/generationhandler/CMakeLists.txt
+++ b/vespalib/src/tests/util/generationhandler/CMakeLists.txt
@@ -4,5 +4,6 @@ vespa_add_executable(vespalib_generationhandler_test_app TEST
generationhandler_test.cpp
DEPENDS
vespalib
+ GTest::GTest
)
vespa_add_test(NAME vespalib_generationhandler_test_app COMMAND vespalib_generationhandler_test_app)
diff --git a/vespalib/src/tests/util/generationhandler/generationhandler_test.cpp b/vespalib/src/tests/util/generationhandler/generationhandler_test.cpp
index f269fe729fa..00da752a749 100644
--- a/vespalib/src/tests/util/generationhandler/generationhandler_test.cpp
+++ b/vespalib/src/tests/util/generationhandler/generationhandler_test.cpp
@@ -1,157 +1,137 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include <vespa/vespalib/testkit/testapp.h>
+#include <vespa/vespalib/gtest/gtest.h>
#include <vespa/vespalib/util/generationhandler.h>
#include <deque>
namespace vespalib {
-typedef GenerationHandler::Guard GenGuard;
+using GenGuard = GenerationHandler::Guard;
-class Test : public vespalib::TestApp {
-private:
- void requireThatGenerationCanBeIncreased();
- void requireThatReadersCanTakeGuards();
- void requireThatGuardsCanBeCopied();
- void requireThatTheFirstUsedGenerationIsCorrect();
- void requireThatGenerationCanGrowLarge();
-public:
- int Main() override;
+class GenerationHandlerTest : public ::testing::Test {
+protected:
+ GenerationHandler gh;
+ GenerationHandlerTest();
+ ~GenerationHandlerTest() override;
};
-void
-Test::requireThatGenerationCanBeIncreased()
+GenerationHandlerTest::GenerationHandlerTest()
+ : ::testing::Test(),
+ gh()
{
- GenerationHandler gh;
- EXPECT_EQUAL(0u, gh.getCurrentGeneration());
- EXPECT_EQUAL(0u, gh.getFirstUsedGeneration());
+}
+
+GenerationHandlerTest::~GenerationHandlerTest() = default;
+
+TEST_F(GenerationHandlerTest, require_that_generation_can_be_increased)
+{
+ EXPECT_EQ(0u, gh.getCurrentGeneration());
+ EXPECT_EQ(0u, gh.getFirstUsedGeneration());
gh.incGeneration();
- EXPECT_EQUAL(1u, gh.getCurrentGeneration());
- EXPECT_EQUAL(1u, gh.getFirstUsedGeneration());
+ EXPECT_EQ(1u, gh.getCurrentGeneration());
+ EXPECT_EQ(1u, gh.getFirstUsedGeneration());
}
-void
-Test::requireThatReadersCanTakeGuards()
+TEST_F(GenerationHandlerTest, require_that_readers_can_take_guards)
{
- GenerationHandler gh;
- EXPECT_EQUAL(0u, gh.getGenerationRefCount(0));
+ EXPECT_EQ(0u, gh.getGenerationRefCount(0));
{
GenGuard g1 = gh.takeGuard();
- EXPECT_EQUAL(1u, gh.getGenerationRefCount(0));
+ EXPECT_EQ(1u, gh.getGenerationRefCount(0));
{
GenGuard g2 = gh.takeGuard();
- EXPECT_EQUAL(2u, gh.getGenerationRefCount(0));
+ EXPECT_EQ(2u, gh.getGenerationRefCount(0));
gh.incGeneration();
{
GenGuard g3 = gh.takeGuard();
- EXPECT_EQUAL(2u, gh.getGenerationRefCount(0));
- EXPECT_EQUAL(1u, gh.getGenerationRefCount(1));
- EXPECT_EQUAL(3u, gh.getGenerationRefCount());
+ EXPECT_EQ(2u, gh.getGenerationRefCount(0));
+ EXPECT_EQ(1u, gh.getGenerationRefCount(1));
+ EXPECT_EQ(3u, gh.getGenerationRefCount());
}
- EXPECT_EQUAL(2u, gh.getGenerationRefCount(0));
- EXPECT_EQUAL(0u, gh.getGenerationRefCount(1));
+ EXPECT_EQ(2u, gh.getGenerationRefCount(0));
+ EXPECT_EQ(0u, gh.getGenerationRefCount(1));
gh.incGeneration();
{
GenGuard g3 = gh.takeGuard();
- EXPECT_EQUAL(2u, gh.getGenerationRefCount(0));
- EXPECT_EQUAL(0u, gh.getGenerationRefCount(1));
- EXPECT_EQUAL(1u, gh.getGenerationRefCount(2));
+ EXPECT_EQ(2u, gh.getGenerationRefCount(0));
+ EXPECT_EQ(0u, gh.getGenerationRefCount(1));
+ EXPECT_EQ(1u, gh.getGenerationRefCount(2));
}
- EXPECT_EQUAL(2u, gh.getGenerationRefCount(0));
- EXPECT_EQUAL(0u, gh.getGenerationRefCount(1));
- EXPECT_EQUAL(0u, gh.getGenerationRefCount(2));
+ EXPECT_EQ(2u, gh.getGenerationRefCount(0));
+ EXPECT_EQ(0u, gh.getGenerationRefCount(1));
+ EXPECT_EQ(0u, gh.getGenerationRefCount(2));
}
- EXPECT_EQUAL(1u, gh.getGenerationRefCount(0));
- EXPECT_EQUAL(0u, gh.getGenerationRefCount(1));
- EXPECT_EQUAL(0u, gh.getGenerationRefCount(2));
+ EXPECT_EQ(1u, gh.getGenerationRefCount(0));
+ EXPECT_EQ(0u, gh.getGenerationRefCount(1));
+ EXPECT_EQ(0u, gh.getGenerationRefCount(2));
}
- EXPECT_EQUAL(0u, gh.getGenerationRefCount(0));
- EXPECT_EQUAL(0u, gh.getGenerationRefCount(1));
- EXPECT_EQUAL(0u, gh.getGenerationRefCount(2));
+ EXPECT_EQ(0u, gh.getGenerationRefCount(0));
+ EXPECT_EQ(0u, gh.getGenerationRefCount(1));
+ EXPECT_EQ(0u, gh.getGenerationRefCount(2));
}
-void
-Test::requireThatGuardsCanBeCopied()
+TEST_F(GenerationHandlerTest, require_that_guards_can_be_copied)
{
- GenerationHandler gh;
GenGuard g1 = gh.takeGuard();
- EXPECT_EQUAL(1u, gh.getGenerationRefCount(0));
+ EXPECT_EQ(1u, gh.getGenerationRefCount(0));
GenGuard g2(g1);
- EXPECT_EQUAL(2u, gh.getGenerationRefCount(0));
+ EXPECT_EQ(2u, gh.getGenerationRefCount(0));
gh.incGeneration();
GenGuard g3 = gh.takeGuard();
- EXPECT_EQUAL(2u, gh.getGenerationRefCount(0));
- EXPECT_EQUAL(1u, gh.getGenerationRefCount(1));
+ EXPECT_EQ(2u, gh.getGenerationRefCount(0));
+ EXPECT_EQ(1u, gh.getGenerationRefCount(1));
g3 = g2;
- EXPECT_EQUAL(3u, gh.getGenerationRefCount(0));
- EXPECT_EQUAL(0u, gh.getGenerationRefCount(1));
+ EXPECT_EQ(3u, gh.getGenerationRefCount(0));
+ EXPECT_EQ(0u, gh.getGenerationRefCount(1));
}
-void
-Test::requireThatTheFirstUsedGenerationIsCorrect()
+TEST_F(GenerationHandlerTest, require_that_the_first_used_generation_is_correct)
{
- GenerationHandler gh;
- EXPECT_EQUAL(0u, gh.getFirstUsedGeneration());
+ EXPECT_EQ(0u, gh.getFirstUsedGeneration());
gh.incGeneration();
- EXPECT_EQUAL(1u, gh.getFirstUsedGeneration());
+ EXPECT_EQ(1u, gh.getFirstUsedGeneration());
{
GenGuard g1 = gh.takeGuard();
gh.incGeneration();
- EXPECT_EQUAL(1u, gh.getGenerationRefCount());
- EXPECT_EQUAL(1u, gh.getFirstUsedGeneration());
+ EXPECT_EQ(1u, gh.getGenerationRefCount());
+ EXPECT_EQ(1u, gh.getFirstUsedGeneration());
}
- EXPECT_EQUAL(1u, gh.getFirstUsedGeneration());
+ EXPECT_EQ(1u, gh.getFirstUsedGeneration());
gh.updateFirstUsedGeneration(); // Only writer should call this
- EXPECT_EQUAL(0u, gh.getGenerationRefCount());
- EXPECT_EQUAL(2u, gh.getFirstUsedGeneration());
+ EXPECT_EQ(0u, gh.getGenerationRefCount());
+ EXPECT_EQ(2u, gh.getFirstUsedGeneration());
{
GenGuard g1 = gh.takeGuard();
gh.incGeneration();
gh.incGeneration();
- EXPECT_EQUAL(1u, gh.getGenerationRefCount());
- EXPECT_EQUAL(2u, gh.getFirstUsedGeneration());
+ EXPECT_EQ(1u, gh.getGenerationRefCount());
+ EXPECT_EQ(2u, gh.getFirstUsedGeneration());
{
GenGuard g2 = gh.takeGuard();
- EXPECT_EQUAL(2u, gh.getFirstUsedGeneration());
+ EXPECT_EQ(2u, gh.getFirstUsedGeneration());
}
}
- EXPECT_EQUAL(2u, gh.getFirstUsedGeneration());
+ EXPECT_EQ(2u, gh.getFirstUsedGeneration());
gh.updateFirstUsedGeneration(); // Only writer should call this
- EXPECT_EQUAL(0u, gh.getGenerationRefCount());
- EXPECT_EQUAL(4u, gh.getFirstUsedGeneration());
+ EXPECT_EQ(0u, gh.getGenerationRefCount());
+ EXPECT_EQ(4u, gh.getFirstUsedGeneration());
}
-void
-Test::requireThatGenerationCanGrowLarge()
+TEST_F(GenerationHandlerTest, require_that_generation_can_grow_large)
{
- GenerationHandler gh;
std::deque<GenGuard> guards;
for (size_t i = 0; i < 10000; ++i) {
- EXPECT_EQUAL(i, gh.getCurrentGeneration());
+ EXPECT_EQ(i, gh.getCurrentGeneration());
guards.push_back(gh.takeGuard()); // take guard on current generation
if (i >= 128) {
- EXPECT_EQUAL(i - 128, gh.getFirstUsedGeneration());
+ EXPECT_EQ(i - 128, gh.getFirstUsedGeneration());
guards.pop_front();
- EXPECT_EQUAL(128u, gh.getGenerationRefCount());
+ EXPECT_EQ(128u, gh.getGenerationRefCount());
}
gh.incGeneration();
}
}
-int
-Test::Main()
-{
- TEST_INIT("generationhandler_test");
-
- TEST_DO(requireThatGenerationCanBeIncreased());
- TEST_DO(requireThatReadersCanTakeGuards());
- TEST_DO(requireThatGuardsCanBeCopied());
- TEST_DO(requireThatTheFirstUsedGenerationIsCorrect());
- TEST_DO(requireThatGenerationCanGrowLarge());
-
- TEST_DONE();
-}
-
}
-TEST_APPHOOK(vespalib::Test);
+GTEST_MAIN_RUN_ALL_TESTS()
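
For readers unfamiliar with the pattern: the conversion above boils down to a fixture class owning the GenerationHandler plus TEST_F cases that share it, with GTEST_MAIN_RUN_ALL_TESTS() replacing the old TEST_APPHOOK entry point. A condensed, self-contained sketch of that pattern follows; it is simplified from the commit (the real test declares its constructor and destructor out of line and lives in the vespalib namespace), so treat it as an illustration rather than the committed source.

    #include <vespa/vespalib/gtest/gtest.h>
    #include <vespa/vespalib/util/generationhandler.h>

    using vespalib::GenerationHandler;

    class GenerationHandlerTest : public ::testing::Test {
    protected:
        GenerationHandler gh;   // a fresh handler is constructed for every TEST_F case
    };

    TEST_F(GenerationHandlerTest, generation_starts_at_zero)
    {
        EXPECT_EQ(0u, gh.getCurrentGeneration());
        EXPECT_EQ(0u, gh.getFirstUsedGeneration());
    }

    GTEST_MAIN_RUN_ALL_TESTS()
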
diff --git a/vespalib/src/tests/util/generationhandler_stress/CMakeLists.txt b/vespalib/src/tests/util/generationhandler_stress/CMakeLists.txt
index 569db489e3c..7e5a5af79b5 100644
--- a/vespalib/src/tests/util/generationhandler_stress/CMakeLists.txt
+++ b/vespalib/src/tests/util/generationhandler_stress/CMakeLists.txt
@@ -4,5 +4,6 @@ vespa_add_executable(vespalib_generation_handler_stress_test_app
generation_handler_stress_test.cpp
DEPENDS
vespalib
+ GTest::GTest
)
-vespa_add_test(NAME vespalib_generation_handler_stress_test_app COMMAND vespalib_generation_handler_stress_test_app BENCHMARK)
+vespa_add_test(NAME vespalib_generation_handler_stress_test_app COMMAND vespalib_generation_handler_stress_test_app --smoke-test)
diff --git a/vespalib/src/tests/util/generationhandler_stress/generation_handler_stress_test.cpp b/vespalib/src/tests/util/generationhandler_stress/generation_handler_stress_test.cpp
index fa2c525b518..0689909da09 100644
--- a/vespalib/src/tests/util/generationhandler_stress/generation_handler_stress_test.cpp
+++ b/vespalib/src/tests/util/generationhandler_stress/generation_handler_stress_test.cpp
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/log/log.h>
LOG_SETUP("generation_handler_stress_test");
-#include <vespa/vespalib/testkit/testapp.h>
+#include <vespa/vespalib/gtest/gtest.h>
#include <vespa/vespalib/util/generationhandler.h>
#include <vespa/vespalib/util/threadstackexecutor.h>
#include <vespa/vespalib/util/size_literals.h>
@@ -10,6 +10,12 @@ using vespalib::Executor;
using vespalib::GenerationHandler;
using vespalib::ThreadStackExecutor;
+namespace {
+
+bool smoke_test = false;
+const vespalib::string smoke_test_option = "--smoke-test";
+
+}
struct WorkContext
{
@@ -21,26 +27,28 @@ struct WorkContext
}
};
-struct Fixture {
+class Fixture : public ::testing::Test {
+protected:
GenerationHandler _generationHandler;
uint32_t _readThreads;
ThreadStackExecutor _writer; // 1 write thread
- ThreadStackExecutor _readers; // multiple reader threads
+ std::unique_ptr<ThreadStackExecutor> _readers; // multiple reader threads
std::atomic<long> _readSeed;
std::atomic<long> _doneWriteWork;
std::atomic<long> _doneReadWork;
std::atomic<int> _stopRead;
bool _reportWork;
- Fixture(uint32_t readThreads = 1);
-
+ Fixture();
~Fixture();
- void readWork(const WorkContext &context);
- void writeWork(uint32_t cnt, WorkContext &context);
+ void set_read_threads(uint32_t read_threads);
+
uint32_t getReadThreads() const { return _readThreads; }
void stressTest(uint32_t writeCnt);
-
+public:
+ void readWork(const WorkContext &context);
+ void writeWork(uint32_t cnt, WorkContext &context);
private:
Fixture(const Fixture &index) = delete;
Fixture(Fixture &&index) = delete;
@@ -49,23 +57,27 @@ private:
};
-Fixture::Fixture(uint32_t readThreads)
- : _generationHandler(),
- _readThreads(readThreads),
+Fixture::Fixture()
+ : ::testing::Test(),
+ _generationHandler(),
+ _readThreads(1),
_writer(1, 128_Ki),
- _readers(readThreads, 128_Ki),
+ _readers(),
_doneWriteWork(0),
_doneReadWork(0),
_stopRead(0),
_reportWork(false)
{
+ set_read_threads(1);
}
Fixture::~Fixture()
{
- _readers.sync();
- _readers.shutdown();
+ if (_readers) {
+ _readers->sync();
+ _readers->shutdown();
+ }
_writer.sync();
_writer.shutdown();
if (_reportWork) {
@@ -75,6 +87,16 @@ Fixture::~Fixture()
}
}
+void
+Fixture::set_read_threads(uint32_t read_threads)
+{
+ if (_readers) {
+ _readers->sync();
+ _readers->shutdown();
+ }
+ _readThreads = read_threads;
+ _readers = std::make_unique<ThreadStackExecutor>(read_threads, 128_Ki);
+}
void
Fixture::readWork(const WorkContext &context)
@@ -85,7 +107,7 @@ Fixture::readWork(const WorkContext &context)
for (i = 0; i < cnt && _stopRead.load() == 0; ++i) {
auto guard = _generationHandler.takeGuard();
auto generation = context._generation.load(std::memory_order_relaxed);
- EXPECT_GREATER_EQUAL(generation, guard.getGeneration());
+ EXPECT_GE(generation, guard.getGeneration());
}
_doneReadWork += i;
LOG(info, "done %u read work", i);
@@ -150,19 +172,32 @@ Fixture::stressTest(uint32_t writeCnt)
auto context = std::make_shared<WorkContext>();
_writer.execute(std::make_unique<WriteWorkTask>(*this, writeCnt, context));
for (uint32_t i = 0; i < readThreads; ++i) {
- _readers.execute(std::make_unique<ReadWorkTask>(*this, context));
+ _readers->execute(std::make_unique<ReadWorkTask>(*this, context));
}
+ _writer.sync();
+ _readers->sync();
}
+using GenerationHandlerStressTest = Fixture;
-TEST_F("stress test, 2 readers", Fixture(2))
+TEST_F(GenerationHandlerStressTest, stress_test_2_readers)
{
- f.stressTest(1000000);
+ set_read_threads(2);
+ stressTest(smoke_test ? 10000 : 1000000);
}
-TEST_F("stress test, 4 readers", Fixture(4))
+TEST_F(GenerationHandlerStressTest, stress_test_4_readers)
{
- f.stressTest(1000000);
+ set_read_threads(4);
+ stressTest(smoke_test ? 10000 : 1000000);
}
-TEST_MAIN() { TEST_RUN_ALL(); }
+int main(int argc, char **argv) {
+ if (argc > 1 && argv[1] == smoke_test_option) {
+ smoke_test = true;
+ ++argv;
+ --argc;
+ }
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
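
The new main() hand-parses --smoke-test before GoogleTest sees the argument list, since InitGoogleTest would otherwise reject the unknown flag. A minimal standalone sketch of the same idea, using plain gtest headers and std::strcmp instead of the vespalib::string option constant used in the committed code:

    #include <gtest/gtest.h>
    #include <cstring>

    static bool smoke_test = false;   // reduced iteration counts when set

    int main(int argc, char **argv) {
        // Strip the custom option before InitGoogleTest parses the remaining flags.
        if (argc > 1 && std::strcmp(argv[1], "--smoke-test") == 0) {
            smoke_test = true;
            ++argv;   // gtest only reads argv[0] as the program name, so shifting is harmless
            --argc;
        }
        ::testing::InitGoogleTest(&argc, argv);
        return RUN_ALL_TESTS();
    }

Together with the CMakeLists change above, ctest now invokes the binary as vespalib_generation_handler_stress_test_app --smoke-test, so routine test runs do 10000 writes per case instead of the full 1000000.
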