diff options
64 files changed, 1198 insertions, 546 deletions
diff --git a/client/go/internal/cli/cmd/cert.go b/client/go/internal/cli/cmd/cert.go index 95206b7e77d..1fa5339e42e 100644 --- a/client/go/internal/cli/cmd/cert.go +++ b/client/go/internal/cli/cmd/cert.go @@ -4,9 +4,7 @@ package cmd import ( - "errors" "fmt" - "io" "os" "path/filepath" @@ -18,8 +16,8 @@ import ( func newCertCmd(cli *CLI) *cobra.Command { var ( - noApplicationPackage bool - overwriteCertificate bool + skipApplicationPackage bool + overwriteCertificate bool ) cmd := &cobra.Command{ Use: "cert", @@ -60,11 +58,12 @@ $ vespa auth cert -a my-tenant.my-app.my-instance path/to/application/package`, SilenceUsage: true, Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return doCert(cli, overwriteCertificate, noApplicationPackage, args) + return doCert(cli, overwriteCertificate, skipApplicationPackage, args) }, } cmd.Flags().BoolVarP(&overwriteCertificate, "force", "f", false, "Force overwrite of existing certificate and private key") - cmd.Flags().BoolVarP(&noApplicationPackage, "no-add", "N", false, "Do not add certificate to the application package") + // TODO(mpolden): Stop adding certificate to application package and remove this flag + cmd.Flags().BoolVarP(&skipApplicationPackage, "no-add", "N", false, "Do not add certificate to the application package") cmd.MarkPersistentFlagRequired(applicationFlag) return cmd } @@ -95,18 +94,11 @@ $ vespa auth cert add -a my-tenant.my-app.my-instance path/to/application/packag return cmd } -func doCert(cli *CLI, overwriteCertificate, noApplicationPackage bool, args []string) error { +func doCert(cli *CLI, overwriteCertificate, skipApplicationPackage bool, args []string) error { app, err := cli.config.application() if err != nil { return err } - var pkg vespa.ApplicationPackage - if !noApplicationPackage { - pkg, err = cli.applicationPackageFrom(args, false) - if err != nil { - return err - } - } targetType, err := cli.targetType() if err != nil { return err @@ -122,11 +114,6 @@ func doCert(cli *CLI, overwriteCertificate, noApplicationPackage bool, args []st if !overwriteCertificate { hint := "Use -f flag to force overwriting" - if !noApplicationPackage { - if pkg.HasCertificate() { - return errHint(fmt.Errorf("application package %s already contains a certificate", pkg.Path), hint) - } - } if util.PathExists(privateKeyFile) { return errHint(fmt.Errorf("private key %s already exists", color.CyanString(privateKeyFile)), hint) } @@ -134,91 +121,86 @@ func doCert(cli *CLI, overwriteCertificate, noApplicationPackage bool, args []st return errHint(fmt.Errorf("certificate %s already exists", color.CyanString(certificateFile)), hint) } } - if !noApplicationPackage { - if pkg.IsZip() { - hint := "Try running 'mvn clean' before 'vespa auth cert', and then 'mvn package'" - return errHint(fmt.Errorf("cannot add certificate to compressed application package %s", pkg.Path), hint) - } - } keyPair, err := vespa.CreateKeyPair() if err != nil { return err } - var pkgCertificateFile string - if !noApplicationPackage { - pkgCertificateFile = filepath.Join(pkg.Path, "security", "clients.pem") - if err := os.MkdirAll(filepath.Dir(pkgCertificateFile), 0755); err != nil { - return fmt.Errorf("could not create security directory: %w", err) - } - if err := keyPair.WriteCertificateFile(pkgCertificateFile, overwriteCertificate); err != nil { - return fmt.Errorf("could not write certificate to application package: %w", err) - } - } if err := keyPair.WriteCertificateFile(certificateFile, overwriteCertificate); err != nil { return fmt.Errorf("could not write certificate: %w", err) } if err := keyPair.WritePrivateKeyFile(privateKeyFile, overwriteCertificate); err != nil { return fmt.Errorf("could not write private key: %w", err) } - if !noApplicationPackage { - cli.printSuccess("Certificate written to ", color.CyanString(pkgCertificateFile)) - } cli.printSuccess("Certificate written to ", color.CyanString(certificateFile)) cli.printSuccess("Private key written to ", color.CyanString(privateKeyFile)) + if !skipApplicationPackage { + return doCertAdd(cli, overwriteCertificate, args) + } return nil } func doCertAdd(cli *CLI, overwriteCertificate bool, args []string) error { - app, err := cli.config.application() - if err != nil { - return err - } pkg, err := cli.applicationPackageFrom(args, false) if err != nil { return err } - targetType, err := cli.targetType() + target, err := cli.target(targetOptions{}) if err != nil { return err } - certificateFile, err := cli.config.certificatePath(app, targetType.name) - if err != nil { - return err + if pkg.HasCertificate() && !overwriteCertificate { + return errHint(fmt.Errorf("application package %s already contains a certificate", pkg.Path), "Use -f flag to force overwriting") } + return maybeCopyCertificate(true, false, cli, target, pkg) +} - if pkg.IsZip() { - hint := "Try running 'mvn clean' before 'vespa auth cert add', and then 'mvn package'" - return errHint(fmt.Errorf("unable to add certificate to compressed application package: %s", pkg.Path), hint) +func maybeCopyCertificate(force, ignoreZip bool, cli *CLI, target vespa.Target, pkg vespa.ApplicationPackage) error { + if pkg.IsZip() && !ignoreZip { + hint := "Try running 'mvn clean', then 'vespa auth cert add' and finally 'mvn package'" + return errHint(fmt.Errorf("cannot add certificate to compressed application package: %s", pkg.Path), hint) } - - pkgCertificateFile := filepath.Join(pkg.Path, "security", "clients.pem") - if err := os.MkdirAll(filepath.Dir(pkgCertificateFile), 0755); err != nil { - return fmt.Errorf("could not create security directory: %w", err) + if force { + return copyCertificate(cli, target, pkg) } - src, err := os.Open(certificateFile) - if errors.Is(err, os.ErrNotExist) { - return errHint(fmt.Errorf("there is not key pair generated for application '%s'", app), "Try running 'vespa auth cert' to generate it") - } else if err != nil { - return fmt.Errorf("could not open certificate file: %w", err) - } - defer src.Close() - flags := os.O_CREATE | os.O_RDWR - if overwriteCertificate { - flags |= os.O_TRUNC - } else { - flags |= os.O_EXCL - } - dst, err := os.OpenFile(pkgCertificateFile, flags, 0755) - if errors.Is(err, os.ErrExist) { - return errHint(fmt.Errorf("application package %s already contains a certificate", pkg.Path), "Use -f flag to force overwriting") - } else if err != nil { - return fmt.Errorf("could not open application certificate file for writing: %w", err) + if pkg.HasCertificate() { + return nil } - if _, err := io.Copy(dst, src); err != nil { - return fmt.Errorf("could not copy certificate file to application: %w", err) + if cli.isTerminal() { + cli.printWarning("Application package does not contain " + color.CyanString("security/clients.pem") + ", which is required for deployments to Vespa Cloud") + ok, err := cli.confirm("Do you want to copy the certificate of application " + color.GreenString(target.Deployment().Application.String()) + " into this application package?") + if err != nil { + return err + } + if ok { + return copyCertificate(cli, target, pkg) + } } + return errHint(fmt.Errorf("deployment to Vespa Cloud requires certificate in application package"), + "See https://cloud.vespa.ai/en/security/guide", + "Pass --add-cert to use the certificate of the current application") +} - cli.printSuccess("Certificate written to ", color.CyanString(pkgCertificateFile)) - return nil +func copyCertificate(cli *CLI, target vespa.Target, pkg vespa.ApplicationPackage) error { + tlsOptions, err := cli.config.readTLSOptions(target.Deployment().Application, target.Type()) + if err != nil { + return err + } + hint := "Try generating the certificate with 'vespa auth cert'" + if tlsOptions.CertificateFile == "" { + return errHint(fmt.Errorf("no certificate exists for "+target.Deployment().Application.String()), hint) + } + data, err := os.ReadFile(tlsOptions.CertificateFile) + if err != nil { + return errHint(fmt.Errorf("could not read certificate file: %w", err)) + } + dstPath := filepath.Join(pkg.Path, "security", "clients.pem") + if err := os.MkdirAll(filepath.Dir(dstPath), 0755); err != nil { + return fmt.Errorf("could not create security directory: %w", err) + } + err = util.AtomicWriteFile(dstPath, data) + if err == nil { + cli.printSuccess("Copied certificate from ", tlsOptions.CertificateFile, " to ", dstPath) + } + return err } diff --git a/client/go/internal/cli/cmd/cert_test.go b/client/go/internal/cli/cmd/cert_test.go index d6b47083e7b..0a20ab8eb1a 100644 --- a/client/go/internal/cli/cmd/cert_test.go +++ b/client/go/internal/cli/cmd/cert_test.go @@ -22,13 +22,21 @@ func TestCert(t *testing.T) { }) } +func configureCloud(t *testing.T, cli *CLI) { + require.Nil(t, cli.Run("config", "set", "application", "t1.a1.i1")) + require.Nil(t, cli.Run("config", "set", "target", "cloud")) + require.Nil(t, cli.Run("auth", "api-key")) +} + func testCert(t *testing.T, subcommand []string) { appDir, pkgDir := mock.ApplicationPackageDir(t, false, false) - cli, stdout, stderr := newTestCLI(t) - args := append(subcommand, "-a", "t1.a1.i1", pkgDir) - err := cli.Run(args...) - assert.Nil(t, err) + cli, stdout, _ := newTestCLI(t) + configureCloud(t, cli) + stdout.Reset() + + args := append(subcommand, pkgDir) + require.Nil(t, cli.Run(args...)) app, err := vespa.ApplicationFromString("t1.a1.i1") assert.Nil(t, err) @@ -38,12 +46,7 @@ func testCert(t *testing.T, subcommand []string) { certificate := filepath.Join(homeDir, app.String(), "data-plane-public-cert.pem") privateKey := filepath.Join(homeDir, app.String(), "data-plane-private-key.pem") - assert.Equal(t, fmt.Sprintf("Success: Certificate written to %s\nSuccess: Certificate written to %s\nSuccess: Private key written to %s\n", pkgCertificate, certificate, privateKey), stdout.String()) - - args = append(subcommand, "-a", "t1.a1.i1", pkgDir) - err = cli.Run(args...) - assert.NotNil(t, err) - assert.Contains(t, stderr.String(), fmt.Sprintf("Error: application package %s already contains a certificate", appDir)) + assert.Equal(t, fmt.Sprintf("Success: Certificate written to %s\nSuccess: Private key written to %s\nSuccess: Copied certificate from %s to %s\n", certificate, privateKey, certificate, pkgCertificate), stdout.String()) } func TestCertCompressedPackage(t *testing.T) { @@ -61,8 +64,11 @@ func testCertCompressedPackage(t *testing.T, subcommand []string) { assert.Nil(t, err) cli, stdout, stderr := newTestCLI(t) + configureCloud(t, cli) + stdout.Reset() + stderr.Reset() - args := append(subcommand, "-a", "t1.a1.i1", pkgDir) + args := append(subcommand, pkgDir) err = cli.Run(args...) assert.NotNil(t, err) assert.Contains(t, stderr.String(), "Error: cannot add certificate to compressed application package") @@ -70,7 +76,7 @@ func testCertCompressedPackage(t *testing.T, subcommand []string) { err = os.Remove(zipFile) assert.Nil(t, err) - args = append(subcommand, "-f", "-a", "t1.a1.i1", pkgDir) + args = append(subcommand, "-f", pkgDir) err = cli.Run(args...) assert.Nil(t, err) assert.Contains(t, stdout.String(), "Success: Certificate written to") @@ -79,30 +85,30 @@ func testCertCompressedPackage(t *testing.T, subcommand []string) { func TestCertAdd(t *testing.T) { cli, stdout, stderr := newTestCLI(t) - err := cli.Run("auth", "cert", "-N", "-a", "t1.a1.i1") - assert.Nil(t, err) + configureCloud(t, cli) + stdout.Reset() + require.Nil(t, cli.Run("auth", "cert", "-N")) appDir, pkgDir := mock.ApplicationPackageDir(t, false, false) stdout.Reset() - err = cli.Run("auth", "cert", "add", "-a", "t1.a1.i1", pkgDir) - assert.Nil(t, err) + require.Nil(t, cli.Run("auth", "cert", "add", pkgDir)) pkgCertificate := filepath.Join(appDir, "security", "clients.pem") - assert.Equal(t, fmt.Sprintf("Success: Certificate written to %s\n", pkgCertificate), stdout.String()) + homeDir := cli.config.homeDir + certificate := filepath.Join(homeDir, "t1.a1.i1", "data-plane-public-cert.pem") + assert.Equal(t, fmt.Sprintf("Success: Copied certificate from %s to %s\n", certificate, pkgCertificate), stdout.String()) - err = cli.Run("auth", "cert", "add", "-a", "t1.a1.i1", pkgDir) - assert.NotNil(t, err) + require.NotNil(t, cli.Run("auth", "cert", "add", pkgDir)) assert.Contains(t, stderr.String(), fmt.Sprintf("Error: application package %s already contains a certificate", appDir)) stdout.Reset() - err = cli.Run("auth", "cert", "add", "-f", "-a", "t1.a1.i1", pkgDir) - assert.Nil(t, err) - assert.Equal(t, fmt.Sprintf("Success: Certificate written to %s\n", pkgCertificate), stdout.String()) + require.Nil(t, cli.Run("auth", "cert", "add", "-f", pkgDir)) + assert.Equal(t, fmt.Sprintf("Success: Copied certificate from %s to %s\n", certificate, pkgCertificate), stdout.String()) } func TestCertNoAdd(t *testing.T) { cli, stdout, stderr := newTestCLI(t) - - err := cli.Run("auth", "cert", "-N", "-a", "t1.a1.i1") - assert.Nil(t, err) + configureCloud(t, cli) + stdout.Reset() + require.Nil(t, cli.Run("auth", "cert", "-N")) homeDir := cli.config.homeDir app, err := vespa.ApplicationFromString("t1.a1.i1") @@ -112,18 +118,15 @@ func TestCertNoAdd(t *testing.T) { privateKey := filepath.Join(homeDir, app.String(), "data-plane-private-key.pem") assert.Equal(t, fmt.Sprintf("Success: Certificate written to %s\nSuccess: Private key written to %s\n", certificate, privateKey), stdout.String()) - err = cli.Run("auth", "cert", "-N", "-a", "t1.a1.i1") - assert.NotNil(t, err) + require.NotNil(t, cli.Run("auth", "cert", "-N")) assert.Contains(t, stderr.String(), fmt.Sprintf("Error: private key %s already exists", privateKey)) require.Nil(t, os.Remove(privateKey)) stderr.Reset() - err = cli.Run("auth", "cert", "-N", "-a", "t1.a1.i1") - assert.NotNil(t, err) + require.NotNil(t, cli.Run("auth", "cert", "-N")) assert.Contains(t, stderr.String(), fmt.Sprintf("Error: certificate %s already exists", certificate)) stdout.Reset() - err = cli.Run("auth", "cert", "-N", "-f", "-a", "t1.a1.i1") - assert.Nil(t, err) + require.Nil(t, cli.Run("auth", "cert", "-N", "-f")) assert.Equal(t, fmt.Sprintf("Success: Certificate written to %s\nSuccess: Private key written to %s\n", certificate, privateKey), stdout.String()) } diff --git a/client/go/internal/cli/cmd/deploy.go b/client/go/internal/cli/cmd/deploy.go index 76027744268..35b9ee0f300 100644 --- a/client/go/internal/cli/cmd/deploy.go +++ b/client/go/internal/cli/cmd/deploy.go @@ -20,6 +20,7 @@ func newDeployCmd(cli *CLI) *cobra.Command { var ( logLevelArg string versionArg string + copyCert bool ) cmd := &cobra.Command{ Use: "deploy [application-directory]", @@ -67,16 +68,18 @@ $ vespa deploy -t cloud -z perf.aws-us-east-1c`, } opts.Version = version } - + if target.Type() == vespa.TargetCloud { + if err := maybeCopyCertificate(copyCert, true, cli, target, pkg); err != nil { + return err + } + } var result vespa.PrepareResult - err = cli.spinner(cli.Stderr, "Uploading application package ...", func() error { + if err := cli.spinner(cli.Stderr, "Uploading application package ...", func() error { result, err = vespa.Deploy(opts) return err - }) - if err != nil { + }); err != nil { return err } - log.Println() if opts.Target.IsCloud() { cli.printSuccess("Triggered deployment of ", color.CyanString(pkg.Path), " with run ID ", color.CyanString(strconv.FormatInt(result.ID, 10))) @@ -97,6 +100,7 @@ $ vespa deploy -t cloud -z perf.aws-us-east-1c`, } cmd.Flags().StringVarP(&logLevelArg, "log-level", "l", "error", `Log level for Vespa logs. Must be "error", "warning", "info" or "debug"`) cmd.Flags().StringVarP(&versionArg, "version", "V", "", `Override the Vespa runtime version to use in Vespa Cloud`) + cmd.Flags().BoolVarP(©Cert, "add-cert", "A", false, `Copy certificate of the configured application to the current application package`) return cmd } diff --git a/client/go/internal/cli/cmd/deploy_test.go b/client/go/internal/cli/cmd/deploy_test.go index 9eaf878bc5e..78834b7185b 100644 --- a/client/go/internal/cli/cmd/deploy_test.go +++ b/client/go/internal/cli/cmd/deploy_test.go @@ -5,14 +5,62 @@ package cmd import ( + "bytes" + "path/filepath" "strconv" + "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/vespa-engine/vespa/client/go/internal/mock" "github.com/vespa-engine/vespa/client/go/internal/vespa" ) +func TestDeployCloud(t *testing.T) { + pkgDir := filepath.Join(t.TempDir(), "app") + createApplication(t, pkgDir, false, false) + + cli, stdout, stderr := newTestCLI(t, "CI=true", "NO_COLOR=true") + httpClient := &mock.HTTPClient{} + httpClient.NextResponseString(200, `ok`) + cli.httpClient = httpClient + + app := vespa.ApplicationID{Tenant: "t1", Application: "a1", Instance: "i1"} + assert.Nil(t, cli.Run("config", "set", "application", app.String())) + assert.Nil(t, cli.Run("config", "set", "target", "cloud")) + assert.Nil(t, cli.Run("auth", "api-key")) + assert.Nil(t, cli.Run("auth", "cert", "--no-add")) + + stderr.Reset() + require.NotNil(t, cli.Run("deploy", pkgDir)) + certError := `Error: deployment to Vespa Cloud requires certificate in application package +Hint: See https://cloud.vespa.ai/en/security/guide +Hint: Pass --add-cert to use the certificate of the current application +` + assert.Equal(t, certError, stderr.String()) + + require.Nil(t, cli.Run("deploy", "--add-cert", pkgDir)) + assert.Contains(t, stdout.String(), "Success: Triggered deployment") + + // Answer interactive certificate copy prompt + stdout.Reset() + stderr.Reset() + cli.isTerminal = func() bool { return true } + pkgDir2 := filepath.Join(t.TempDir(), "app") + createApplication(t, pkgDir2, false, false) + + var buf bytes.Buffer + buf.WriteString("wat\nthe\nfck\nn\n") + cli.Stdin = &buf + require.NotNil(t, cli.Run("deploy", "--add-cert=false", pkgDir2)) + warning := "Warning: Application package does not contain security/clients.pem, which is required for deployments to Vespa Cloud\n" + assert.Equal(t, warning+strings.Repeat("Error: please answer 'Y' or 'n'\n", 3)+certError, stderr.String()) + buf.WriteString("y\n") + require.Nil(t, cli.Run("deploy", "--add-cert=false", pkgDir2)) + assert.Contains(t, stdout.String(), "Success: Triggered deployment") +} + func TestPrepareZip(t *testing.T) { assertPrepare("testdata/applications/withTarget/target/application.zip", []string{"prepare", "testdata/applications/withTarget/target/application.zip"}, t) @@ -42,7 +90,7 @@ func TestDeployZipWithURLTargetArgument(t *testing.T) { assertDeployRequestMade("http://target:19071", client, t) } -func TestDeployZipWitLocalTargetArgument(t *testing.T) { +func TestDeployZipWithLocalTargetArgument(t *testing.T) { assertDeploy("testdata/applications/withTarget/target/application.zip", []string{"deploy", "testdata/applications/withTarget/target/application.zip", "-t", "local"}, t) } diff --git a/client/go/internal/cli/cmd/document.go b/client/go/internal/cli/cmd/document.go index 0ed68e30ced..6a07121a13b 100644 --- a/client/go/internal/cli/cmd/document.go +++ b/client/go/internal/cli/cmd/document.go @@ -160,8 +160,8 @@ func newDocumentCmd(cli *CLI) *cobra.Command { ) cmd := &cobra.Command{ Use: "document json-file", - Short: "Issue a document operation to Vespa", - Long: `Issue a document operation to Vespa. + Short: "Issue a single document operation to Vespa", + Long: `Issue a single document operation to Vespa. The operation must be on the format documented in https://docs.vespa.ai/en/reference/document-json-format.html#document-operations diff --git a/client/go/internal/cli/cmd/feed.go b/client/go/internal/cli/cmd/feed.go index 6d368cb210b..cad3568a89f 100644 --- a/client/go/internal/cli/cmd/feed.go +++ b/client/go/internal/cli/cmd/feed.go @@ -56,8 +56,8 @@ func newFeedCmd(cli *CLI) *cobra.Command { var options feedOptions cmd := &cobra.Command{ Use: "feed FILE [FILE]...", - Short: "Feed documents to a Vespa cluster", - Long: `Feed documents to a Vespa cluster. + Short: "Feed multiple document operations to a Vespa cluster", + Long: `Feed multiple document operations to a Vespa cluster. This command can be used to feed large amounts of documents to a Vespa cluster efficiently. diff --git a/client/go/internal/cli/cmd/login.go b/client/go/internal/cli/cmd/login.go index d2075bdfcf0..54c0dfef770 100644 --- a/client/go/internal/cli/cmd/login.go +++ b/client/go/internal/cli/cmd/login.go @@ -4,7 +4,6 @@ import ( "fmt" "log" "os" - "strings" "time" "github.com/pkg/browser" @@ -46,7 +45,10 @@ func newLoginCmd(cli *CLI) *cobra.Command { log.Printf("Your Device Confirmation code is: %s\n", state.UserCode) - auto_open := confirm(cli, "Automatically open confirmation page in your default browser?") + auto_open, err := cli.confirm("Automatically open confirmation page in your default browser?") + if err != nil { + return err + } if auto_open { log.Printf("Opened link in your browser: %s\n", state.VerificationURI) @@ -90,22 +92,3 @@ func newLoginCmd(cli *CLI) *cobra.Command { }, } } - -func confirm(cli *CLI, question string) bool { - for { - var answer string - - fmt.Fprintf(cli.Stdout, "%s [Y/n] ", question) - fmt.Fscanln(cli.Stdin, &answer) - - answer = strings.TrimSpace(strings.ToLower(answer)) - - if answer == "y" || answer == "" { - return true - } else if answer == "n" { - return false - } else { - log.Printf("Please answer Y or N.\n") - } - } -} diff --git a/client/go/internal/cli/cmd/prod.go b/client/go/internal/cli/cmd/prod.go index 318dcefe7f7..6daa8db6e81 100644 --- a/client/go/internal/cli/cmd/prod.go +++ b/client/go/internal/cli/cmd/prod.go @@ -103,7 +103,8 @@ https://cloud.vespa.ai/en/reference/deployment`, } func newProdDeployCmd(cli *CLI) *cobra.Command { - return &cobra.Command{ + copyCert := false + cmd := &cobra.Command{ Use: "deploy", Aliases: []string{"submit"}, // TODO: Remove in Vespa 9 Short: "Deploy an application to production", @@ -145,6 +146,9 @@ $ vespa prod deploy`, if err != nil { return err } + if err := maybeCopyCertificate(copyCert, true, cli, target, pkg); err != nil { + return err + } if err := vespa.Submit(opts); err != nil { return fmt.Errorf("could not deploy application: %w", err) } else { @@ -155,6 +159,8 @@ $ vespa prod deploy`, return nil }, } + cmd.Flags().BoolVarP(©Cert, "add-cert", "A", false, `Copy certificate of the configured application to the current application package`) + return cmd } func writeWithBackup(stdout io.Writer, pkg vespa.ApplicationPackage, filename, contents string) error { diff --git a/client/go/internal/cli/cmd/prod_test.go b/client/go/internal/cli/cmd/prod_test.go index 6a6cc494dcb..a01056b7178 100644 --- a/client/go/internal/cli/cmd/prod_test.go +++ b/client/go/internal/cli/cmd/prod_test.go @@ -9,7 +9,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/vespa-engine/vespa/client/go/internal/mock" "github.com/vespa-engine/vespa/client/go/internal/util" "github.com/vespa-engine/vespa/client/go/internal/vespa" @@ -172,19 +171,7 @@ func prodDeploy(pkgDir string, t *testing.T) { assert.Nil(t, cli.Run("config", "set", "application", app.String())) assert.Nil(t, cli.Run("config", "set", "target", "cloud")) assert.Nil(t, cli.Run("auth", "api-key")) - assert.Nil(t, cli.Run("auth", "cert", pkgDir)) - - // Remove certificate as it's not required for submission (but it must be part of the application package) - if path, err := cli.config.privateKeyPath(app, vespa.TargetCloud); err == nil { - os.RemoveAll(path) - } else { - require.Nil(t, err) - } - if path, err := cli.config.certificatePath(app, vespa.TargetCloud); err == nil { - os.RemoveAll(path) - } else { - require.Nil(t, err) - } + assert.Nil(t, cli.Run("auth", "cert", "--no-add")) // Zipping requires relative paths, so must let command run from pkgDir, then reset cwd for subsequent tests. if cwd, err := os.Getwd(); err != nil { @@ -198,11 +185,11 @@ func prodDeploy(pkgDir string, t *testing.T) { stdout.Reset() cli.Environment["VESPA_CLI_API_KEY_FILE"] = filepath.Join(cli.config.homeDir, "t1.api-key.pem") - assert.Nil(t, cli.Run("prod", "deploy")) + assert.Nil(t, cli.Run("prod", "deploy", "--add-cert")) assert.Contains(t, stdout.String(), "Success: Deployed") assert.Contains(t, stdout.String(), "See https://console.vespa-cloud.com/tenant/t1/application/a1/prod/deployment for deployment progress") stdout.Reset() - assert.Nil(t, cli.Run("prod", "submit")) // old variant also works + assert.Nil(t, cli.Run("prod", "submit", "--add-cert")) // old variant also works assert.Contains(t, stdout.String(), "Success: Deployed") assert.Contains(t, stdout.String(), "See https://console.vespa-cloud.com/tenant/t1/application/a1/prod/deployment for deployment progress") } @@ -218,7 +205,7 @@ func TestProdDeployWithJava(t *testing.T) { assert.Nil(t, cli.Run("config", "set", "application", "t1.a1.i1")) assert.Nil(t, cli.Run("config", "set", "target", "cloud")) assert.Nil(t, cli.Run("auth", "api-key")) - assert.Nil(t, cli.Run("auth", "cert", pkgDir)) + assert.Nil(t, cli.Run("auth", "cert", "--no-add")) // Copy an application package pre-assembled with mvn package testAppDir := filepath.Join("testdata", "applications", "withDeployment", "target") @@ -229,7 +216,7 @@ func TestProdDeployWithJava(t *testing.T) { stdout.Reset() cli.Environment["VESPA_CLI_API_KEY_FILE"] = filepath.Join(cli.config.homeDir, "t1.api-key.pem") - assert.Nil(t, cli.Run("prod", "submit", pkgDir)) + assert.Nil(t, cli.Run("prod", "deploy", pkgDir)) assert.Contains(t, stdout.String(), "Success: Deployed") assert.Contains(t, stdout.String(), "See https://console.vespa-cloud.com/tenant/t1/application/a1/prod/deployment for deployment progress") } @@ -245,7 +232,7 @@ func TestProdDeployInvalidZip(t *testing.T) { assert.Nil(t, cli.Run("config", "set", "application", "t1.a1.i1")) assert.Nil(t, cli.Run("config", "set", "target", "cloud")) assert.Nil(t, cli.Run("auth", "api-key")) - assert.Nil(t, cli.Run("auth", "cert", pkgDir)) + assert.Nil(t, cli.Run("auth", "cert", "--no-add")) // Copy an invalid application package containing relative file names testAppDir := filepath.Join("testdata", "applications", "withInvalidEntries", "target") @@ -254,7 +241,7 @@ func TestProdDeployInvalidZip(t *testing.T) { testZipFile := filepath.Join(testAppDir, "application-test.zip") copyFile(t, filepath.Join(pkgDir, "target", "application-test.zip"), testZipFile) - assert.NotNil(t, cli.Run("prod", "submit", pkgDir)) + assert.NotNil(t, cli.Run("prod", "deploy", pkgDir)) assert.Equal(t, "Error: found invalid path inside zip: ../../../../../../../tmp/foo\n", stderr.String()) } diff --git a/client/go/internal/cli/cmd/root.go b/client/go/internal/cli/cmd/root.go index 17c4fc41625..41de0bfc9d3 100644 --- a/client/go/internal/cli/cmd/root.go +++ b/client/go/internal/cli/cmd/root.go @@ -105,8 +105,7 @@ func New(stdout, stderr io.Writer, environment []string) (*CLI, error) { Short: "The command-line tool for Vespa.ai", Long: `The command-line tool for Vespa.ai. -Use it on Vespa instances running locally, remotely or in the cloud. -Prefer web service API's to this in production. +Use it on Vespa instances running locally, remotely or in Vespa Cloud. Vespa documentation: https://docs.vespa.ai @@ -301,6 +300,26 @@ func (c *CLI) printWarning(msg interface{}, hints ...string) { } } +func (c *CLI) confirm(question string) (bool, error) { + if !c.isTerminal() { + return false, fmt.Errorf("terminal is not interactive") + } + for { + var answer string + fmt.Fprintf(c.Stdout, "%s [Y/n] ", question) + fmt.Fscanln(c.Stdin, &answer) + answer = strings.TrimSpace(strings.ToLower(answer)) + switch answer { + case "y", "": + return true, nil + case "n": + return false, nil + default: + c.printErr(fmt.Errorf("please answer 'Y' or 'n'")) + } + } +} + // target creates a target according the configuration of this CLI and given opts. func (c *CLI) target(opts targetOptions) (vespa.Target, error) { targetType, err := c.targetType() diff --git a/client/go/internal/cli/cmd/visit.go b/client/go/internal/cli/cmd/visit.go index 1875c768c60..a588474bd2b 100644 --- a/client/go/internal/cli/cmd/visit.go +++ b/client/go/internal/cli/cmd/visit.go @@ -89,10 +89,10 @@ func newVisitCmd(cli *CLI) *cobra.Command { ) cmd := &cobra.Command{ Use: "visit", - Short: "Visit and print all documents in a vespa cluster", - Long: `Run visiting to retrieve all documents. + Short: "Visit and print all documents in a Vespa cluster", + Long: `Visit and print all documents in a Vespa cluster. -By default prints each document received on its own line (JSON-L format). +By default prints each document received on its own line (JSONL format). `, Example: `$ vespa visit # get documents from any cluster $ vespa visit --content-cluster search # get documents from cluster named "search" diff --git a/client/go/internal/vespa/deploy.go b/client/go/internal/vespa/deploy.go index 4531af75737..8b2cb6ea05d 100644 --- a/client/go/internal/vespa/deploy.go +++ b/client/go/internal/vespa/deploy.go @@ -21,7 +21,11 @@ import ( "github.com/vespa-engine/vespa/client/go/internal/version" ) -var DefaultApplication = ApplicationID{Tenant: "default", Application: "application", Instance: "default"} +var ( + DefaultApplication = ApplicationID{Tenant: "default", Application: "application", Instance: "default"} + DefaultZone = ZoneID{Environment: "prod", Region: "default"} + DefaultDeployment = Deployment{Application: DefaultApplication, Zone: DefaultZone} +) type ApplicationID struct { Tenant string diff --git a/client/go/internal/vespa/target_custom.go b/client/go/internal/vespa/target_custom.go index 93397287ac8..fd0af0e8d53 100644 --- a/client/go/internal/vespa/target_custom.go +++ b/client/go/internal/vespa/target_custom.go @@ -36,7 +36,7 @@ func (t *customTarget) Type() string { return t.targetType } func (t *customTarget) IsCloud() bool { return false } -func (t *customTarget) Deployment() Deployment { return Deployment{} } +func (t *customTarget) Deployment() Deployment { return DefaultDeployment } func (t *customTarget) createService(name string) (*Service, error) { switch name { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilter.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilter.java index 2217b58c508..2deaf81d338 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilter.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilter.java @@ -49,7 +49,7 @@ class CloudDataPlaneFilter extends Filter implements CloudDataPlaneFilterConfig. var clientsCfg = clients.stream() .map(x -> new CloudDataPlaneFilterConfig.Clients.Builder() .id(x.id()) - .certificates(X509CertificateUtils.toPem(x.certificates())) + .certificates(x.certificates().stream().map(X509CertificateUtils::toPem).toList()) .tokens(tokensConfig(x.tokens())) .permissions(x.permissions())) .toList(); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index 2b5232eba8c..00feb0a1c76 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -527,7 +527,8 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { private Optional<Client> getClient(Element clientElement, DeployState state) { String clientId = XML.attribute("id", clientElement).orElseThrow(); - if (clientId.startsWith("_")) throw new IllegalArgumentException("Invalid client id '%s', id cannot start with '_'".formatted(clientId)); + if (clientId.startsWith("_")) + throw new IllegalArgumentException("Invalid client id '%s', id cannot start with '_'".formatted(clientId)); List<String> permissions = XML.attribute("permissions", clientElement) .map(p -> p.split(",")).stream() .flatMap(Arrays::stream) @@ -554,15 +555,16 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { var tokenId = elem.getAttribute("id"); var token = knownTokens.get(tokenId); if (token == null) - throw new IllegalArgumentException( - "Token '%s' for client '%s' does not exist".formatted(tokenId, clientId)); + log.logApplicationPackage( + WARNING, "Token '%s' for client '%s' does not exist".formatted(tokenId, clientId)); return token; }) .filter(token -> { + if (token == null) return false; boolean empty = token.versions().isEmpty(); if (empty) log.logApplicationPackage( - WARNING, "Token '%s' for client '%s' has no activate versions" + WARNING, "Token '%s' for client '%s' has no active versions" .formatted(token.tokenId(), clientId)); return !empty; }) @@ -622,21 +624,11 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { .map(clientAuth -> clientAuth == AccessControl.ClientAuthentication.need) .orElse(false); - // TODO (mortent): Implement token support in model - boolean enableTokenSupport = deployState.featureFlags().enableDataplaneProxy(); + boolean enableTokenSupport = deployState.featureFlags().enableDataplaneProxy() + && cluster.getClients().stream().anyMatch(c -> !c.tokens().isEmpty()); // Set up component to generate proxy cert if token support is enabled if (enableTokenSupport) { - var tokenChain = new HttpFilterChain("cloud-data-plane-token", HttpFilterChain.Type.SYSTEM); - tokenChain.addInnerComponent(new Filter( - new ChainedComponentModel( - new BundleInstantiationSpecification( - new ComponentSpecification("com.yahoo.jdisc.http.filter.security.misc.BlockingRequestFilter"), - null, new ComponentSpecification("jdisc-security-filters")), - Dependencies.emptyDependencies()))); - - cluster.getHttp().getFilterChains().add(tokenChain); - cluster.addSimpleComponent(DataplaneProxyCredentials.class); cluster.addSimpleComponent(DataplaneProxyService.class); @@ -646,6 +638,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { endpointCertificateSecrets.key()); cluster.addComponent(dataplaneProxy); } + connectorFactory = authorizeClient ? HostedSslConnectorFactory.withProvidedCertificateAndTruststore( serverName, endpointCertificateSecrets, X509CertificateUtils.toPem(clientCertificates), diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilterTest.java index 5bb0254f1cc..e11eec1ffd7 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilterTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilterTest.java @@ -88,6 +88,7 @@ public class CloudDataPlaneFilterTest extends ContainerModelBuilderTestBase { CloudDataPlaneFilterConfig.Clients client = clients.get(0); assertEquals("foo", client.id()); assertIterableEquals(List.of("read", "write"), client.permissions()); + assertTrue(client.tokens().isEmpty()); assertIterableEquals(List.of(X509CertificateUtils.toPem(certificate)), client.certificates()); ConnectorConfig connectorConfig = connectorConfig(); @@ -144,6 +145,7 @@ public class CloudDataPlaneFilterTest extends ContainerModelBuilderTestBase { var tokenClient = cfg.clients().stream().filter(c -> c.id().equals("bar")).findAny().orElse(null); assertNotNull(tokenClient); assertEquals(List.of("read"), tokenClient.permissions()); + assertTrue(tokenClient.certificates().isEmpty()); var expectedTokenCfg = tokenConfig( "my-token", List.of("myfingerprint1", "myfingerprint2"), List.of("myaccesshash1", "myaccesshash2")); assertEquals(List.of(expectedTokenCfg), tokenClient.tokens()); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/versions/VespaVersion.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/versions/VespaVersion.java index 45c00848407..b03098bf18f 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/versions/VespaVersion.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/versions/VespaVersion.java @@ -49,13 +49,13 @@ public record VespaVersion(Version version, if (nonCanaryApplicationsBroken(statistics.version(), failingOnThis, productionOnThis)) return Confidence.broken; - // 'low' unless all canary applications are upgraded - if (productionOnThis.with(UpgradePolicy.canary).size() < all.withProductionDeployment().with(UpgradePolicy.canary).size()) + // 'low' unless all unpinned canary applications are upgraded + if (productionOnThis.with(UpgradePolicy.canary).unpinned().size() < all.withProductionDeployment().with(UpgradePolicy.canary).unpinned().size()) return Confidence.low; - // 'high' if 90% of all default upgrade applications upgraded - if (productionOnThis.with(UpgradePolicy.defaultPolicy).groupingBy(TenantAndApplicationId::from).size() >= - all.withProductionDeployment().with(UpgradePolicy.defaultPolicy).groupingBy(TenantAndApplicationId::from).size() * 0.9) + // 'high' if 90% of all unpinned default upgrade applications upgraded + if (productionOnThis.with(UpgradePolicy.defaultPolicy).unpinned().groupingBy(TenantAndApplicationId::from).size() >= + all.withProductionDeployment().with(UpgradePolicy.defaultPolicy).unpinned().groupingBy(TenantAndApplicationId::from).size() * 0.9) return Confidence.high; return Confidence.normal; diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index c1e12db3a81..fb3ea5c1ab5 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -408,7 +408,7 @@ public class Flags { ZONE_ID); public static final UnboundListFlag<String> WEIGHTED_ENDPOINT_RECORD_TTL = defineListFlag( - "weighted-endpoint-record-ttl", List.of(), String.class, List.of("jonmv"), "2023-05-16", "2023-06-16", + "weighted-endpoint-record-ttl", List.of(), String.class, List.of("jonmv"), "2023-05-16", "2023-09-01", "A list of endpoints and custom TTLs, on the form \"endpoint-fqdn:TTL-seconds\". " + "Where specified, CNAME records are used instead of the default ALIAS records, which have a default 60s TTL.", "Takes effect at redeployment from controller"); diff --git a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudDataPlaneFilter.java b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudDataPlaneFilter.java index 07f586b2123..96602fcd899 100644 --- a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudDataPlaneFilter.java +++ b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudDataPlaneFilter.java @@ -18,7 +18,6 @@ import com.yahoo.security.token.TokenCheckHash; import com.yahoo.security.token.TokenDomain; import com.yahoo.security.token.TokenFingerprint; -import java.nio.charset.StandardCharsets; import java.security.Principal; import java.security.cert.X509Certificate; import java.util.ArrayList; @@ -98,11 +97,14 @@ public class CloudDataPlaneFilter extends JsonSecurityRequestFilterBase { if (!c.certificates().isEmpty()) { List<X509Certificate> certs; try { - certs = c.certificates().stream().map(X509CertificateUtils::fromPem).toList(); + certs = c.certificates().stream() + .flatMap(pem -> X509CertificateUtils.certificateListFromPem(pem).stream()).toList(); } catch (Exception e) { throw new IllegalArgumentException( "Client '%s' contains invalid X.509 certificate PEM: %s".formatted(c.id(), e.toString()), e); } + if (certs.isEmpty()) throw new IllegalArgumentException( + "Client '%s' certificate PEM contains no valid X.509 entries".formatted(c.id())); clients.add(new Client(c.id(), permissions, certs, Map.of())); hasClientRequiringCertificate = true; } else { diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeMetricsDbMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeMetricsDbMaintainer.java index dd4839d131a..4c5ea45f3ec 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeMetricsDbMaintainer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeMetricsDbMaintainer.java @@ -40,9 +40,11 @@ public class NodeMetricsDbMaintainer extends NodeRepositoryMaintainer { Set<ApplicationId> applications = activeNodesByApplication().keySet(); if (applications.isEmpty()) return 1.0; - long pauseMs = interval().toMillis() / applications.size() - 1; // spread requests over interval + long pauseMs = interval().toMillis() / Math.max(4, applications.size()); // spread requests over interval int done = 0; for (ApplicationId application : applications) { + if (shuttingDown()) return asSuccessFactorDeviation(attempts, failures.get()); + attempts++; metricsFetcher.fetchMetrics(application) .whenComplete((metricsResponse, exception) -> handleResponse(metricsResponse, diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java index 387e787c754..7a2508729ed 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java @@ -24,12 +24,14 @@ import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; import static com.yahoo.config.provision.NodeType.confighost; import static com.yahoo.config.provision.NodeType.controllerhost; import static com.yahoo.config.provision.NodeType.proxyhost; +import static java.util.function.Predicate.not; /** * This handles IP address configuration and allocation. @@ -117,6 +119,7 @@ public record IP() { for (var other : sortedNodes) { if (node.equals(other)) continue; if (canAssignIpOf(other, node)) continue; + Predicate<String> sharedIpSpace = other.cloudAccount().equals(node.cloudAccount()) ? __ -> true : IP::isPublic; var addresses = new HashSet<>(node.ipConfig().primary()); var otherAddresses = new HashSet<>(other.ipConfig().primary()); @@ -124,6 +127,7 @@ public record IP() { addresses.addAll(node.ipConfig().pool().asSet()); otherAddresses.addAll(other.ipConfig().pool().asSet()); } + otherAddresses.removeIf(not(sharedIpSpace)); otherAddresses.retainAll(addresses); if (!otherAddresses.isEmpty()) throw new IllegalArgumentException("Cannot assign " + addresses + " to " + node.hostname() + @@ -463,4 +467,10 @@ public record IP() { return ipAddress.contains(":"); } + /** Returns whether given string is a public IP address */ + public static boolean isPublic(String ip) { + InetAddress address = parse(ip); + return ! address.isLoopbackAddress() && ! address.isLinkLocalAddress() && ! address.isSiteLocalAddress(); + } + } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/CapacityPolicies.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/CapacityPolicies.java index ee7650da8c3..8dc9619e5d0 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/CapacityPolicies.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/CapacityPolicies.java @@ -139,9 +139,7 @@ public class CapacityPolicies { new Version(8, 129, 4), new NodeResources(0.25, 1.32, 10, 0.3))); else // arm64 nodes need more memory - return versioned(clusterSpec, Map.of(new Version(0), new NodeResources(0.25, 1.50, 10, 0.3), - new Version(8, 129, 4), new NodeResources(0.25, 1.32, 10, 0.3), - new Version(8, 173, 5), new NodeResources(0.25, 1.50, 10, 0.3))); + return versioned(clusterSpec, Map.of(new Version(0), new NodeResources(0.25, 1.50, 10, 0.3))); } private NodeResources logserverResources(Architecture architecture) { diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/NodeRepositoryTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/NodeRepositoryTest.java index f45a5cd1c5f..605bf514f03 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/NodeRepositoryTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/NodeRepositoryTest.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.provision; +import com.yahoo.config.provision.CloudAccount; import com.yahoo.config.provision.NodeType; import com.yahoo.vespa.hosted.provision.node.Agent; import com.yahoo.vespa.hosted.provision.node.History; @@ -18,6 +19,7 @@ import java.util.function.Predicate; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -50,6 +52,53 @@ public class NodeRepositoryTest { } @Test + public void test_ip_conflicts() { + NodeRepositoryTester tester = new NodeRepositoryTester(); + IP.Config ipConfig = IP.Config.of(Set.of("1.2.3.4", "10.2.3.4"), Set.of("1.2.3.4", "10.2.3.4")); + IP.Config publicIpConfig = IP.Config.of(Set.of("1.2.3.4"), Set.of("1.2.3.4")); + IP.Config privateIpConfig = IP.Config.of(Set.of("10.2.3.4"), Set.of("10.2.3.4")); + + Node host1 = Node.create("id1", ipConfig, "host1", tester.nodeFlavors().getFlavorOrThrow("default"), NodeType.host) + .build(); + tester.nodeRepository().nodes().addNodes(List.of(host1), Agent.system); + + Node publicHost2 = Node.create("id2", publicIpConfig, "host2", tester.nodeFlavors().getFlavorOrThrow("default"), NodeType.host) + .build(); + + Node publicEnclaveHost2 = Node.create("id2", publicIpConfig, "host2", tester.nodeFlavors().getFlavorOrThrow("default"), NodeType.host) + .cloudAccount(CloudAccount.from("gcp:foo-bar-baz")) + .build(); + + // Public IP conflicts inside an account are not allowed + assertEquals("Cannot assign [1.2.3.4] to host2: [1.2.3.4] already assigned to host1", + assertThrows(IllegalArgumentException.class, + () -> tester.nodeRepository().nodes().addNodes(List.of(publicHost2), Agent.system)) + .getMessage()); + + // Public IP conflicts across accounts are not allowed + assertEquals("Cannot assign [1.2.3.4] to host2: [1.2.3.4] already assigned to host1", + assertThrows(IllegalArgumentException.class, + () -> tester.nodeRepository().nodes().addNodes(List.of(publicEnclaveHost2), Agent.system)) + .getMessage()); + + Node privateHost2 = Node.create("id2", privateIpConfig, "host2", tester.nodeFlavors().getFlavorOrThrow("default"), NodeType.host) + .build(); + + Node privateEnclaveHost2 = Node.create("id2", privateIpConfig, "host2", tester.nodeFlavors().getFlavorOrThrow("default"), NodeType.host) + .cloudAccount(CloudAccount.from("gcp:foo-bar-baz")) + .build(); + + // Private IP conflicts inside accounts are not allowed + assertEquals("Cannot assign [10.2.3.4] to host2: [10.2.3.4] already assigned to host1", + assertThrows(IllegalArgumentException.class, + () -> tester.nodeRepository().nodes().addNodes(List.of(privateHost2), Agent.system)) + .getMessage()); + + // Private IP conflicts across accounts are allowed + tester.nodeRepository().nodes().addNodes(List.of(privateEnclaveHost2), Agent.system); + } + + @Test public void only_allow_docker_containers_remove_in_ready() { NodeRepositoryTester tester = new NodeRepositoryTester(); tester.addHost("id1", "host1", "docker", NodeType.tenant); diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/NodeRepositoryTester.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/NodeRepositoryTester.java index 4d0b3e75740..00c4d95b0da 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/NodeRepositoryTester.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/NodeRepositoryTester.java @@ -55,6 +55,7 @@ public class NodeRepositoryTester { public NodeRepository nodeRepository() { return nodeRepository; } public MockCurator curator() { return curator; } + public NodeFlavors nodeFlavors() { return nodeFlavors; } public List<Node> getNodes(NodeType type, Node.State ... inState) { return nodeRepository.nodes().list(inState).nodeType(type).asList(); diff --git a/searchlib/CMakeLists.txt b/searchlib/CMakeLists.txt index 8959a2dd2e0..07045684d6e 100644 --- a/searchlib/CMakeLists.txt +++ b/searchlib/CMakeLists.txt @@ -92,7 +92,6 @@ vespa_define_module( src/tests/attribute/postinglist src/tests/attribute/postinglistattribute src/tests/attribute/raw_attribute - src/tests/attribute/raw_buffer_type_mapper src/tests/attribute/reference_attribute src/tests/attribute/save_target src/tests/attribute/searchable diff --git a/searchlib/src/tests/attribute/attribute_test.cpp b/searchlib/src/tests/attribute/attribute_test.cpp index 870562355d1..a78dbabe4e3 100644 --- a/searchlib/src/tests/attribute/attribute_test.cpp +++ b/searchlib/src/tests/attribute/attribute_test.cpp @@ -1089,8 +1089,8 @@ AttributeTest::testArray() { AttributePtr ptr = createAttribute("a-int32", Config(BasicType::INT32, CollectionType::ARRAY)); ptr->updateStat(true); - EXPECT_EQ(495664u, ptr->getStatus().getAllocated()); - EXPECT_EQ(487904u, ptr->getStatus().getUsed()); + EXPECT_EQ(297952u, ptr->getStatus().getAllocated()); + EXPECT_EQ(256092u, ptr->getStatus().getUsed()); addDocs(ptr, numDocs); testArray<IntegerAttribute, AttributeVector::largeint_t>(ptr, values); } @@ -1099,8 +1099,8 @@ AttributeTest::testArray() cfg.setFastSearch(true); AttributePtr ptr = createAttribute("flags", cfg); ptr->updateStat(true); - EXPECT_EQ(495664u, ptr->getStatus().getAllocated()); - EXPECT_EQ(487904u, ptr->getStatus().getUsed()); + EXPECT_EQ(297952u, ptr->getStatus().getAllocated()); + EXPECT_EQ(256092u, ptr->getStatus().getUsed()); addDocs(ptr, numDocs); testArray<IntegerAttribute, AttributeVector::largeint_t>(ptr, values); } @@ -1109,8 +1109,8 @@ AttributeTest::testArray() cfg.setFastSearch(true); AttributePtr ptr = createAttribute("a-fs-int32", cfg); ptr->updateStat(true); - EXPECT_EQ(852300u, ptr->getStatus().getAllocated()); - EXPECT_EQ(589556u, ptr->getStatus().getUsed()); + EXPECT_EQ(654588u, ptr->getStatus().getAllocated()); + EXPECT_EQ(357744u, ptr->getStatus().getUsed()); addDocs(ptr, numDocs); testArray<IntegerAttribute, AttributeVector::largeint_t>(ptr, values); } @@ -1128,8 +1128,8 @@ AttributeTest::testArray() cfg.setFastSearch(true); AttributePtr ptr = createAttribute("a-fs-float", cfg); ptr->updateStat(true); - EXPECT_EQ(852300u, ptr->getStatus().getAllocated()); - EXPECT_EQ(589556u, ptr->getStatus().getUsed()); + EXPECT_EQ(654588u, ptr->getStatus().getAllocated()); + EXPECT_EQ(357744u, ptr->getStatus().getUsed()); addDocs(ptr, numDocs); testArray<FloatingPointAttribute, double>(ptr, values); } @@ -1140,8 +1140,8 @@ AttributeTest::testArray() { AttributePtr ptr = createAttribute("a-string", Config(BasicType::STRING, CollectionType::ARRAY)); ptr->updateStat(true); - EXPECT_EQ(607968u + sizeof_large_string_entry, ptr->getStatus().getAllocated()); - EXPECT_EQ(540748u + sizeof_large_string_entry, ptr->getStatus().getUsed()); + EXPECT_EQ(410256u + sizeof_large_string_entry, ptr->getStatus().getAllocated()); + EXPECT_EQ(308936u + sizeof_large_string_entry, ptr->getStatus().getUsed()); addDocs(ptr, numDocs); testArray<StringAttribute, string>(ptr, values); } @@ -1150,8 +1150,8 @@ AttributeTest::testArray() cfg.setFastSearch(true); AttributePtr ptr = createAttribute("afs-string", cfg); ptr->updateStat(true); - EXPECT_EQ(858176u + sizeof_large_string_entry, ptr->getStatus().getAllocated()); - EXPECT_EQ(592480u + sizeof_large_string_entry, ptr->getStatus().getUsed()); + EXPECT_EQ(660464u + sizeof_large_string_entry, ptr->getStatus().getAllocated()); + EXPECT_EQ(360668u + sizeof_large_string_entry, ptr->getStatus().getUsed()); addDocs(ptr, numDocs); testArray<StringAttribute, string>(ptr, values); } diff --git a/searchlib/src/tests/attribute/raw_buffer_type_mapper/CMakeLists.txt b/searchlib/src/tests/attribute/raw_buffer_type_mapper/CMakeLists.txt deleted file mode 100644 index c860770536d..00000000000 --- a/searchlib/src/tests/attribute/raw_buffer_type_mapper/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -vespa_add_executable(searchlib_raw_buffer_type_mapper_test_app TEST - SOURCES - raw_buffer_type_mapper_test.cpp - DEPENDS - searchlib - GTest::GTest -) -vespa_add_test(NAME searchlib_raw_buffer_type_mapper_test_app COMMAND searchlib_raw_buffer_type_mapper_test_app) diff --git a/searchlib/src/tests/attribute/raw_buffer_type_mapper/raw_buffer_type_mapper_test.cpp b/searchlib/src/tests/attribute/raw_buffer_type_mapper/raw_buffer_type_mapper_test.cpp deleted file mode 100644 index 74ec839670e..00000000000 --- a/searchlib/src/tests/attribute/raw_buffer_type_mapper/raw_buffer_type_mapper_test.cpp +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#include <vespa/searchlib/attribute/raw_buffer_type_mapper.h> -#include <vespa/vespalib/gtest/gtest.h> - -using search::attribute::RawBufferTypeMapper; - -constexpr double default_grow_factor = 1.03; - -class RawBufferTypeMapperTest : public testing::Test -{ -protected: - RawBufferTypeMapper _mapper; - RawBufferTypeMapperTest(); - ~RawBufferTypeMapperTest() override; - std::vector<size_t> get_array_sizes(uint32_t num_array_sizes); - std::vector<size_t> get_large_array_sizes(uint32_t num_large_arrays); - void select_type_ids(std::vector<size_t> array_sizes); - void setup_mapper(uint32_t max_small_buffer_type_id, double grow_factor); - static uint32_t calc_max_small_array_type_id(double grow_factor); -}; - -RawBufferTypeMapperTest::RawBufferTypeMapperTest() - : testing::Test(), - _mapper(5, default_grow_factor) -{ -} - -RawBufferTypeMapperTest::~RawBufferTypeMapperTest() = default; - -void -RawBufferTypeMapperTest::setup_mapper(uint32_t max_small_buffer_type_id, double grow_factor) -{ - _mapper = RawBufferTypeMapper(max_small_buffer_type_id, grow_factor); -} - -std::vector<size_t> -RawBufferTypeMapperTest::get_array_sizes(uint32_t num_array_sizes) -{ - std::vector<size_t> array_sizes; - for (uint32_t type_id = 1; type_id <= num_array_sizes; ++type_id) { - array_sizes.emplace_back(_mapper.get_array_size(type_id)); - } - return array_sizes; -} - -std::vector<size_t> -RawBufferTypeMapperTest::get_large_array_sizes(uint32_t num_large_array_sizes) -{ - setup_mapper(num_large_array_sizes * 100, default_grow_factor); - std::vector<size_t> result; - for (uint32_t i = 0; i < num_large_array_sizes; ++i) { - uint32_t type_id = (i + 1) * 100; - auto array_size = _mapper.get_array_size(type_id); - result.emplace_back(array_size); - EXPECT_EQ(type_id, _mapper.get_type_id(array_size)); - EXPECT_EQ(type_id, _mapper.get_type_id(array_size - 1)); - if (i + 1 == num_large_array_sizes) { - EXPECT_EQ(0u, _mapper.get_type_id(array_size + 1)); - } else { - EXPECT_EQ(type_id + 1, _mapper.get_type_id(array_size + 1)); - } - } - return result; -} - -void -RawBufferTypeMapperTest::select_type_ids(std::vector<size_t> array_sizes) -{ - uint32_t type_id = 0; - for (auto array_size : array_sizes) { - ++type_id; - EXPECT_EQ(type_id, _mapper.get_type_id(array_size)); - EXPECT_EQ(type_id, _mapper.get_type_id(array_size - 1)); - if (array_size == array_sizes.back()) { - // Fallback to indirect storage, using type id 0 - EXPECT_EQ(0u, _mapper.get_type_id(array_size + 1)); - } else { - EXPECT_EQ(type_id + 1, _mapper.get_type_id(array_size + 1)); - } - } -} - -uint32_t -RawBufferTypeMapperTest::calc_max_small_array_type_id(double grow_factor) -{ - RawBufferTypeMapper mapper(1000, grow_factor); - return mapper.get_max_small_array_type_id(1000); -} - -TEST_F(RawBufferTypeMapperTest, array_sizes_are_calculated) -{ - EXPECT_EQ((std::vector<size_t>{8, 12, 16, 20, 24}), get_array_sizes(5)); -} - -TEST_F(RawBufferTypeMapperTest, type_ids_are_selected) -{ - select_type_ids({8, 12, 16, 20, 24}); -} - -TEST_F(RawBufferTypeMapperTest, large_arrays_grows_exponentially) -{ - EXPECT_EQ((std::vector<size_t>{1148, 22796, 438572, 8429384}), get_large_array_sizes(4)); -} - -TEST_F(RawBufferTypeMapperTest, avoid_array_size_overflow) -{ - EXPECT_EQ(29, calc_max_small_array_type_id(2.0)); - EXPECT_EQ(379, calc_max_small_array_type_id(1.05)); - EXPECT_EQ(468, calc_max_small_array_type_id(1.04)); - EXPECT_EQ(610, calc_max_small_array_type_id(1.03)); - EXPECT_EQ(892, calc_max_small_array_type_id(1.02)); -} - -GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/tensor/tensor_buffer_type_mapper/tensor_buffer_type_mapper_test.cpp b/searchlib/src/tests/tensor/tensor_buffer_type_mapper/tensor_buffer_type_mapper_test.cpp index 17612f08271..08c0901de01 100644 --- a/searchlib/src/tests/tensor/tensor_buffer_type_mapper/tensor_buffer_type_mapper_test.cpp +++ b/searchlib/src/tests/tensor/tensor_buffer_type_mapper/tensor_buffer_type_mapper_test.cpp @@ -153,7 +153,7 @@ TEST_P(TensorBufferTypeMapperTest, large_arrays_grows_exponentially) TEST_P(TensorBufferTypeMapperTest, avoid_array_size_overflow) { TensorBufferTypeMapper mapper(400, 2.0, &_ops); - EXPECT_GE(30, mapper.get_max_small_array_type_id(1000)); + EXPECT_GE(30, mapper.get_max_type_id(1000)); } GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchlib/attribute/CMakeLists.txt b/searchlib/src/vespa/searchlib/attribute/CMakeLists.txt index 6c1f4871161..896226f005d 100644 --- a/searchlib/src/vespa/searchlib/attribute/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/attribute/CMakeLists.txt @@ -111,7 +111,6 @@ vespa_add_library(searchlib_attribute OBJECT raw_buffer_store.cpp raw_buffer_store_reader.cpp raw_buffer_store_writer.cpp - raw_buffer_type_mapper.cpp raw_multi_value_read_view.cpp readerbase.cpp reference_attribute.cpp diff --git a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.cpp b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.cpp index 3c1fc15088f..153f4148a64 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.cpp +++ b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.cpp @@ -5,6 +5,8 @@ #include "i_enum_store.h" #include <vespa/searchcommon/attribute/multivalue.h> #include <vespa/vespalib/datastore/atomic_entry_ref.h> +#include <vespa/vespalib/datastore/array_store_dynamic_type_mapper.hpp> +#include <vespa/vespalib/datastore/dynamic_array_buffer_type.hpp> #include <vespa/vespalib/datastore/buffer_type.hpp> #include <vespa/vespalib/util/array.hpp> diff --git a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h index 4fce64aa762..0725a574aa0 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h +++ b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h @@ -5,6 +5,8 @@ #include "multi_value_mapping_base.h" #include "multi_value_mapping_read_view.h" #include <vespa/vespalib/datastore/array_store.h> +#include <vespa/vespalib/datastore/array_store_dynamic_type_mapper.h> +#include <vespa/vespalib/datastore/dynamic_array_buffer_type.h> #include <vespa/vespalib/util/address_space.h> namespace search::attribute { @@ -19,9 +21,13 @@ public: using MultiValueType = ElemT; using RefType = RefT; using ReadView = MultiValueMappingReadView<ElemT, RefT>; + + static constexpr double array_store_grow_factor = 1.03; + static constexpr uint32_t array_store_max_type_id = 300; private: using ArrayRef = vespalib::ArrayRef<ElemT>; - using ArrayStore = vespalib::datastore::ArrayStore<ElemT, RefT>; + using ArrayStoreTypeMapper = vespalib::datastore::ArrayStoreDynamicTypeMapper<ElemT>; + using ArrayStore = vespalib::datastore::ArrayStore<ElemT, RefT, ArrayStoreTypeMapper>; using generation_t = vespalib::GenerationHandler::generation_t; using ConstArrayRef = vespalib::ConstArrayRef<ElemT>; @@ -70,7 +76,7 @@ public: void set_compaction_spec(vespalib::datastore::CompactionSpec compaction_spec) noexcept { _store.set_compaction_spec(compaction_spec); } - static vespalib::datastore::ArrayStoreConfig optimizedConfigForHugePage(size_t maxSmallArraySize, + static vespalib::datastore::ArrayStoreConfig optimizedConfigForHugePage(size_t max_type_id, size_t hugePageSize, size_t smallPageSize, size_t min_num_entries_for_new_buffer, diff --git a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp index ab68bea58cc..99808b11e92 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp @@ -12,7 +12,7 @@ MultiValueMapping<ElemT,RefT>::MultiValueMapping(const vespalib::datastore::Arra const vespalib::GrowStrategy &gs, std::shared_ptr<vespalib::alloc::MemoryAllocator> memory_allocator) : MultiValueMappingBase(gs, ArrayStore::getGenerationHolderLocation(_store), memory_allocator), - _store(storeCfg, std::move(memory_allocator)) + _store(storeCfg, std::move(memory_allocator), ArrayStoreTypeMapper(storeCfg.max_type_id(), array_store_grow_factor)) { } @@ -65,14 +65,15 @@ MultiValueMapping<ElemT, RefT>::getAddressSpaceUsage() const { template <typename ElemT, typename RefT> vespalib::datastore::ArrayStoreConfig -MultiValueMapping<ElemT, RefT>::optimizedConfigForHugePage(size_t maxSmallArraySize, +MultiValueMapping<ElemT, RefT>::optimizedConfigForHugePage(size_t max_type_id, size_t hugePageSize, size_t smallPageSize, size_t min_num_entries_for_new_buffer, float allocGrowFactor, bool enable_free_lists) { - auto result = ArrayStore::optimizedConfigForHugePage(maxSmallArraySize, hugePageSize, smallPageSize, min_num_entries_for_new_buffer, allocGrowFactor); + ArrayStoreTypeMapper mapper(max_type_id, array_store_grow_factor); + auto result = ArrayStore::optimizedConfigForHugePage(max_type_id, mapper, hugePageSize, smallPageSize, min_num_entries_for_new_buffer, allocGrowFactor); result.enable_free_lists(enable_free_lists); return result; } diff --git a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping_read_view.h b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping_read_view.h index 609989208c3..1f9875133d3 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping_read_view.h +++ b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping_read_view.h @@ -4,6 +4,8 @@ #include <vespa/vespalib/datastore/atomic_entry_ref.h> #include <vespa/vespalib/datastore/array_store.h> +#include <vespa/vespalib/datastore/array_store_dynamic_type_mapper.h> +#include <vespa/vespalib/datastore/dynamic_array_buffer_type.h> #include <vespa/vespalib/util/address_space.h> namespace search::attribute { @@ -16,7 +18,8 @@ class MultiValueMappingReadView { using AtomicEntryRef = vespalib::datastore::AtomicEntryRef; using Indices = vespalib::ConstArrayRef<AtomicEntryRef>; - using ArrayStore = vespalib::datastore::ArrayStore<ElemT, RefT>; + using ArrayStoreTypeMapper = vespalib::datastore::ArrayStoreDynamicTypeMapper<ElemT>; + using ArrayStore = vespalib::datastore::ArrayStore<ElemT, RefT, ArrayStoreTypeMapper>; Indices _indices; const ArrayStore* _store; diff --git a/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp index aea7e57897f..d8ada97fa2c 100644 --- a/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp @@ -25,7 +25,7 @@ MultiValueAttribute<B, M>:: MultiValueAttribute(const vespalib::string &baseFileName, const AttributeVector::Config &cfg) : B(baseFileName, cfg), - _mvMapping(MultiValueMapping::optimizedConfigForHugePage(1023, + _mvMapping(MultiValueMapping::optimizedConfigForHugePage(MultiValueMapping::array_store_max_type_id, vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE, vespalib::alloc::MemoryAllocator::PAGE_SIZE, 8 * 1024, diff --git a/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp b/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp index 74894728ff4..cd9e0508344 100644 --- a/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp +++ b/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp @@ -17,53 +17,20 @@ namespace search::attribute { RawBufferStore::RawBufferStore(std::shared_ptr<vespalib::alloc::MemoryAllocator> allocator, uint32_t max_small_buffer_type_id, double grow_factor) : _array_store(ArrayStoreType::optimizedConfigForHugePage(max_small_buffer_type_id, - RawBufferTypeMapper(max_small_buffer_type_id, grow_factor), + TypeMapper(max_small_buffer_type_id, grow_factor), MemoryAllocator::HUGEPAGE_SIZE, MemoryAllocator::PAGE_SIZE, 8_Ki, ALLOC_GROW_FACTOR), - std::move(allocator), RawBufferTypeMapper(max_small_buffer_type_id, grow_factor)) + std::move(allocator), TypeMapper(max_small_buffer_type_id, grow_factor)) { } RawBufferStore::~RawBufferStore() = default; -vespalib::ConstArrayRef<char> -RawBufferStore::get(EntryRef ref) const -{ - auto array = _array_store.get(ref); - uint32_t size = 0; - assert(array.size() >= sizeof(size)); - memcpy(&size, array.data(), sizeof(size)); - assert(array.size() >= sizeof(size) + size); - return {array.data() + sizeof(size), size}; } -EntryRef -RawBufferStore::set(vespalib::ConstArrayRef<char> raw) -{ - uint32_t size = raw.size(); - if (size == 0) { - return EntryRef(); - } - size_t buffer_size = raw.size() + sizeof(size); - auto& mapper = _array_store.get_mapper(); - auto type_id = mapper.get_type_id(buffer_size); - auto array_size = (type_id != 0) ? mapper.get_array_size(type_id) : buffer_size; - assert(array_size >= buffer_size); - auto ref = _array_store.allocate(array_size); - auto buf = _array_store.get_writable(ref); - memcpy(buf.data(), &size, sizeof(size)); - memcpy(buf.data() + sizeof(size), raw.data(), size); - if (array_size > buffer_size) { - memset(buf.data() + buffer_size, 0, array_size - buffer_size); - } - return ref; -} +namespace vespalib::datastore { -std::unique_ptr<vespalib::datastore::ICompactionContext> -RawBufferStore::start_compact(const vespalib::datastore::CompactionStrategy& compaction_strategy) -{ - return _array_store.compact_worst(compaction_strategy); -} +template class ArrayStore<char, EntryRefT<19>, ArrayStoreDynamicTypeMapper<char>>; } diff --git a/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.h b/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.h index bc3c189b329..a3f5b564846 100644 --- a/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.h +++ b/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.h @@ -3,7 +3,8 @@ #pragma once #include <vespa/vespalib/datastore/array_store.h> -#include "raw_buffer_type_mapper.h" +#include <vespa/vespalib/datastore/array_store_dynamic_type_mapper.h> +#include <vespa/vespalib/datastore/dynamic_array_buffer_type.h> namespace search::attribute { @@ -15,20 +16,21 @@ class RawBufferStore { using EntryRef = vespalib::datastore::EntryRef; using RefType = vespalib::datastore::EntryRefT<19>; - using ArrayStoreType = vespalib::datastore::ArrayStore<char, RefType, RawBufferTypeMapper>; + using TypeMapper = vespalib::datastore::ArrayStoreDynamicTypeMapper<char>; + using ArrayStoreType = vespalib::datastore::ArrayStore<char, RefType, TypeMapper>; using generation_t = vespalib::GenerationHandler::generation_t; ArrayStoreType _array_store; public: RawBufferStore(std::shared_ptr<vespalib::alloc::MemoryAllocator> allocator, uint32_t max_small_buffer_type_id, double grow_factor); ~RawBufferStore(); - EntryRef set(vespalib::ConstArrayRef<char> raw); - vespalib::ConstArrayRef<char> get(EntryRef ref) const; + EntryRef set(vespalib::ConstArrayRef<char> raw) { return _array_store.add(raw); }; + vespalib::ConstArrayRef<char> get(EntryRef ref) const { return _array_store.get(ref); } void remove(EntryRef ref) { _array_store.remove(ref); } vespalib::MemoryUsage update_stat(const vespalib::datastore::CompactionStrategy& compaction_strategy) { return _array_store.update_stat(compaction_strategy); } vespalib::AddressSpace get_address_space_usage() const { return _array_store.addressSpaceUsage(); } bool consider_compact() const noexcept { return _array_store.consider_compact(); } - std::unique_ptr<vespalib::datastore::ICompactionContext> start_compact(const vespalib::datastore::CompactionStrategy& compaction_strategy); + std::unique_ptr<vespalib::datastore::ICompactionContext> start_compact(const vespalib::datastore::CompactionStrategy& compaction_strategy) { return _array_store.compact_worst(compaction_strategy); } void reclaim_memory(generation_t oldest_used_gen) { _array_store.reclaim_memory(oldest_used_gen); } void assign_generation(generation_t current_gen) { _array_store.assign_generation(current_gen); } void set_initializing(bool initializing) { _array_store.setInitializing(initializing); } diff --git a/searchlib/src/vespa/searchlib/attribute/raw_buffer_type_mapper.cpp b/searchlib/src/vespa/searchlib/attribute/raw_buffer_type_mapper.cpp deleted file mode 100644 index 29245fb403a..00000000000 --- a/searchlib/src/vespa/searchlib/attribute/raw_buffer_type_mapper.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#include "raw_buffer_type_mapper.h" -#include <vespa/vespalib/datastore/aligner.h> -#include <algorithm> -#include <cmath> -#include <limits> - -using vespalib::datastore::Aligner; -using vespalib::datastore::ArrayStoreTypeMapper; - -namespace search::attribute { - -RawBufferTypeMapper::RawBufferTypeMapper() - : ArrayStoreTypeMapper() -{ -} - -RawBufferTypeMapper::RawBufferTypeMapper(uint32_t max_small_buffer_type_id, double grow_factor) - : ArrayStoreTypeMapper() -{ - Aligner<4> aligner; - _array_sizes.reserve(max_small_buffer_type_id + 1); - _array_sizes.emplace_back(0); // type id 0 uses LargeArrayBufferType<char> - size_t array_size = 8u; - for (uint32_t type_id = 1; type_id <= max_small_buffer_type_id; ++type_id) { - if (type_id > 1) { - array_size = std::max(array_size + 4, static_cast<size_t>(std::floor(array_size * grow_factor))); - array_size = aligner.align(array_size); - } - if (array_size > std::numeric_limits<uint32_t>::max()) { - break; - } - _array_sizes.emplace_back(array_size); - } -} - -RawBufferTypeMapper::~RawBufferTypeMapper() = default; - -} diff --git a/searchlib/src/vespa/searchlib/attribute/raw_buffer_type_mapper.h b/searchlib/src/vespa/searchlib/attribute/raw_buffer_type_mapper.h deleted file mode 100644 index a2538fe51e5..00000000000 --- a/searchlib/src/vespa/searchlib/attribute/raw_buffer_type_mapper.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#pragma once - -#include <vespa/vespalib/datastore/array_store_type_mapper.h> - -namespace vespalib::datastore { - -template <typename EntryT> class SmallArrayBufferType; -template <typename EntryT> class LargeArrayBufferType; - -} - -namespace search::attribute { - -/* - * This class provides mapping between type ids and array sizes needed for - * storing a raw value. - */ -class RawBufferTypeMapper : public vespalib::datastore::ArrayStoreTypeMapper -{ -public: - using SmallBufferType = vespalib::datastore::SmallArrayBufferType<char>; - using LargeBufferType = vespalib::datastore::LargeArrayBufferType<char>; - - RawBufferTypeMapper(); - RawBufferTypeMapper(uint32_t max_small_buffer_type_id, double grow_factor); - ~RawBufferTypeMapper(); - size_t get_entry_size(uint32_t type_id) const { return get_array_size(type_id); } -}; - -} diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp index a908ebd7210..a78d9cefc64 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp @@ -12,7 +12,7 @@ using vespalib::datastore::EntryRef; namespace { -constexpr uint32_t max_small_array_type_id = 64; +constexpr uint32_t max_type_id = 64; constexpr size_t min_num_arrays_for_new_buffer = 512_Ki; constexpr float alloc_grow_factor = 0.3; @@ -46,7 +46,7 @@ HnswNodeidMapping::HnswNodeidMapping() : _refs(1), _grow_strategy(16, 1.0, 0, 0), // These are the same parameters as the default in rcuvector.h _nodeid_limit(1), // Starting with nodeid=1 matches that we also start with docid=1. - _nodeids(NodeidStore::optimizedConfigForHugePage(max_small_array_type_id, + _nodeids(NodeidStore::optimizedConfigForHugePage(max_type_id, vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE, vespalib::alloc::MemoryAllocator::PAGE_SIZE, min_num_arrays_for_new_buffer, diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt index d2892ee4429..c1d7e17b457 100644 --- a/vespalib/CMakeLists.txt +++ b/vespalib/CMakeLists.txt @@ -126,6 +126,7 @@ vespa_define_module( src/tests/net/tls/policy_checking_certificate_verifier src/tests/net/tls/protocol_snooping src/tests/net/tls/transport_options + src/tests/nexus src/tests/nice src/tests/objects/identifiable src/tests/objects/nbostream @@ -149,6 +150,7 @@ vespa_define_module( src/tests/require src/tests/runnable_pair src/tests/rusage + src/tests/rw_spin_lock src/tests/sequencedtaskexecutor src/tests/sha1 src/tests/shared_operation_throttler diff --git a/vespalib/src/tests/datastore/array_store/array_store_test.cpp b/vespalib/src/tests/datastore/array_store/array_store_test.cpp index c9f1230346c..e21674e9436 100644 --- a/vespalib/src/tests/datastore/array_store/array_store_test.cpp +++ b/vespalib/src/tests/datastore/array_store/array_store_test.cpp @@ -74,7 +74,7 @@ struct ArrayStoreTest : public TestT type_mapper_grow_factor(type_mapper_grow_factor_in) {} explicit ArrayStoreTest(const ArrayStoreConfig &storeCfg) - : type_mapper(storeCfg.maxSmallArrayTypeId(), 2.0), + : type_mapper(storeCfg.max_type_id(), 2.0), store(storeCfg, std::make_unique<MemoryAllocatorObserver>(stats), TypeMapperType(type_mapper)), refStore(), generation(1), @@ -290,7 +290,7 @@ TYPED_TEST(NumberStoreTest, control_type_mapper) if constexpr (TestFixture::simple_type_mapper) { GTEST_SKIP() << "Skipping test due to using simple type mapper"; } else { - EXPECT_EQ(3, this->type_mapper.get_max_small_array_type_id(1000)); + EXPECT_EQ(3, this->type_mapper.get_max_type_id(1000)); EXPECT_FALSE(this->type_mapper.is_dynamic_buffer(0)); EXPECT_FALSE(this->type_mapper.is_dynamic_buffer(1)); EXPECT_EQ(1, this->type_mapper.get_array_size(1)); diff --git a/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp b/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp index 16abd065a55..71c1341ae74 100644 --- a/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp +++ b/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp @@ -15,15 +15,15 @@ struct Fixture using EntryRefType = EntryRefT<18>; ArrayStoreConfig cfg; - Fixture(uint32_t maxSmallArrayTypeId, + Fixture(uint32_t max_type_id, const AllocSpec &defaultSpec) - : cfg(maxSmallArrayTypeId, defaultSpec) {} + : cfg(max_type_id, defaultSpec) {} - Fixture(uint32_t maxSmallArrayTypeId, + Fixture(uint32_t max_type_id, size_t hugePageSize, size_t smallPageSize, size_t min_num_entries_for_new_buffer) - : cfg(ArrayStoreConfig::optimizeForHugePage(maxSmallArrayTypeId, + : cfg(ArrayStoreConfig::optimizeForHugePage(max_type_id, [](size_t type_id) noexcept { return type_id * sizeof(int); }, hugePageSize, smallPageSize, EntryRefType::offsetSize(), @@ -55,7 +55,7 @@ constexpr size_t MB = KB * KB; TEST_F("require that default allocation spec is given for all array sizes", Fixture(3, makeSpec(4, 32, 8))) { - EXPECT_EQUAL(3u, f.cfg.maxSmallArrayTypeId()); + EXPECT_EQUAL(3u, f.cfg.max_type_id()); TEST_DO(f.assertSpec(0, makeSpec(4, 32, 8))); TEST_DO(f.assertSpec(1, makeSpec(4, 32, 8))); TEST_DO(f.assertSpec(2, makeSpec(4, 32, 8))); @@ -67,7 +67,7 @@ TEST_F("require that we can generate config optimized for a given huge page", Fi 4 * KB, 8 * KB)) { - EXPECT_EQUAL(1_Ki, f.cfg.maxSmallArrayTypeId()); + EXPECT_EQUAL(1_Ki, f.cfg.max_type_id()); TEST_DO(f.assertSpec(0, 8 * KB)); // large arrays TEST_DO(f.assertSpec(1, 256 * KB)); TEST_DO(f.assertSpec(2, 256 * KB)); diff --git a/vespalib/src/tests/datastore/array_store_dynamic_type_mapper/array_store_dynamic_type_mapper_test.cpp b/vespalib/src/tests/datastore/array_store_dynamic_type_mapper/array_store_dynamic_type_mapper_test.cpp index 86b80aaa695..7ead0b97269 100644 --- a/vespalib/src/tests/datastore/array_store_dynamic_type_mapper/array_store_dynamic_type_mapper_test.cpp +++ b/vespalib/src/tests/datastore/array_store_dynamic_type_mapper/array_store_dynamic_type_mapper_test.cpp @@ -111,7 +111,7 @@ uint32_t TestBase<ElemT>::calc_max_buffer_type_id(double grow_factor) { ArrayStoreDynamicTypeMapper<ElemT> mapper(1000, grow_factor); - return mapper.get_max_small_array_type_id(1000); + return mapper.get_max_type_id(1000); } using ArrayStoreDynamicTypeMapperCharTest = TestBase<char>; diff --git a/vespalib/src/tests/nexus/CMakeLists.txt b/vespalib/src/tests/nexus/CMakeLists.txt new file mode 100644 index 00000000000..4b1b4bc9c25 --- /dev/null +++ b/vespalib/src/tests/nexus/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_nexus_test_app TEST + SOURCES + nexus_test.cpp + DEPENDS + vespalib + GTest::GTest +) +vespa_add_test(NAME vespalib_nexus_test_app COMMAND vespalib_nexus_test_app) diff --git a/vespalib/src/tests/nexus/nexus_test.cpp b/vespalib/src/tests/nexus/nexus_test.cpp new file mode 100644 index 00000000000..09f913dccd1 --- /dev/null +++ b/vespalib/src/tests/nexus/nexus_test.cpp @@ -0,0 +1,92 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/test/nexus.h> +#include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/require.h> + +using namespace vespalib::test; + +TEST(NexusTest, run_void_tasks) { + std::atomic<size_t> value = 0; + auto task = [&value](Nexus &) { + value.fetch_add(1, std::memory_order_relaxed); + }; + Nexus ctx(10); + ctx.run(task); + EXPECT_EQ(value, 10); + ctx.run(task); + EXPECT_EQ(value, 20); +} + +TEST(NexusTest, run_value_tasks_select_thread_0) { + std::atomic<size_t> value = 0; + auto task = [&value](Nexus &ctx) { + value.fetch_add(1, std::memory_order_relaxed); + return ctx.thread_id() + 5; + }; + Nexus ctx(10); + EXPECT_EQ(ctx.run(task), 5); + EXPECT_EQ(value, 10); +} + +TEST(NexusTest, run_value_tasks_merge_results) { + std::atomic<size_t> value = 0; + auto task = [&value](Nexus &) { + return value.fetch_add(1, std::memory_order_relaxed) + 1; + }; + Nexus ctx(10); + EXPECT_EQ(ctx.run(task, Nexus::merge_sum()), 55); + EXPECT_EQ(value, 10); +} + +TEST(NexusTest, run_inline_voted_loop) { + // Each thread wants to run a loop <thread_id> times, but the loop + // condition is a vote between all threads. After 3 iterations, + // threads 0,1,2,3 vote to exit while threads 4,5,6,7,8 vote to + // continue. After 4 iterations, threads 0,1,2,3,4 vote to exit + // while threads 5,6,7,8 vote to continue. The result is that all + // threads end up doing the loop exactly 4 times. + auto res = Nexus(9).run([](Nexus &ctx) { + size_t times = 0; + for (size_t i = 0; ctx.vote(i < ctx.thread_id()); ++i) { + ++times; + } + return times; + }, [](auto a, auto b){ EXPECT_EQ(a, b); return a; }); + EXPECT_EQ(res, 4); +} + +TEST(NexusTest, run_return_type_decay) { + int value = 3; + auto task = [&](Nexus &)->int&{ return value; }; + Nexus ctx(3); + auto res = ctx.run(task); + EXPECT_EQ(res, 3); + EXPECT_EQ(std::addressof(value), std::addressof(task(ctx))); + using task_res_t = decltype(task(ctx)); + using run_res_t = decltype(ctx.run(task)); + static_assert(std::same_as<task_res_t, int&>); + static_assert(std::same_as<run_res_t, int>); +} + +TEST(NexusTest, example_multi_threaded_unit_test) { + int a = 0; + int b = 0; + auto work = [&](Nexus &ctx) { + EXPECT_EQ(ctx.num_threads(), 2); + if (ctx.thread_id() == 0) { + a = 5; + ctx.barrier(); + EXPECT_EQ(b, 7); + } else { + b = 7; + ctx.barrier(); + EXPECT_EQ(a, 5); + } + }; + Nexus(2).run(work); + EXPECT_EQ(a, 5); + EXPECT_EQ(b, 7); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/tests/rw_spin_lock/CMakeLists.txt b/vespalib/src/tests/rw_spin_lock/CMakeLists.txt new file mode 100644 index 00000000000..76bcb918ce9 --- /dev/null +++ b/vespalib/src/tests/rw_spin_lock/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_rw_spin_lock_test_app TEST + SOURCES + rw_spin_lock_test.cpp + DEPENDS + vespalib + GTest::GTest +) +vespa_add_test(NAME vespalib_rw_spin_lock_test_app NO_VALGRIND COMMAND vespalib_rw_spin_lock_test_app) diff --git a/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp b/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp new file mode 100644 index 00000000000..50621338d8c --- /dev/null +++ b/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp @@ -0,0 +1,323 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/util/spin_lock.h> +#include <vespa/vespalib/util/rw_spin_lock.h> +#include <vespa/vespalib/util/time.h> +#include <vespa/vespalib/util/classname.h> +#include <vespa/vespalib/test/thread_meets.h> +#include <vespa/vespalib/test/nexus.h> +#include <vespa/vespalib/gtest/gtest.h> +#include <type_traits> +#include <ranges> +#include <random> +#include <array> + +using namespace vespalib; +using namespace vespalib::test; + +bool bench = false; +duration budget = 250ms; +constexpr size_t LOOP_CNT = 4096; +size_t thread_safety_work = 1'000'000; +size_t state_loop = 1; + +//----------------------------------------------------------------------------- + +struct DummyLock { + constexpr DummyLock() noexcept {} + // BasicLockable + constexpr void lock() noexcept {} + constexpr void unlock() noexcept {} + // SharedLockable + constexpr void lock_shared() noexcept {} + [[nodiscard]] constexpr bool try_lock_shared() noexcept { return true; } + constexpr void unlock_shared() noexcept {} + // rw_upgrade_downgrade_lock + [[nodiscard]] constexpr bool try_convert_read_to_write() noexcept { return true; } + constexpr void convert_write_to_read() noexcept {} +}; + +//----------------------------------------------------------------------------- + +struct MyState { + static constexpr size_t SZ = 5; + std::array<std::atomic<size_t>,SZ> state = {0,0,0,0,0}; + std::atomic<size_t> inconsistent_reads = 0; + std::atomic<size_t> expected_writes = 0; + [[nodiscard]] size_t update() { + std::array<size_t,SZ> tmp; + for (size_t i = 0; i < SZ; ++i) { + tmp[i] = state[i].load(std::memory_order_relaxed); + } + for (size_t n = 0; n < state_loop; ++n) { + for (size_t i = 0; i < SZ; ++i) { + state[i].store(tmp[i] + 1, std::memory_order_relaxed); + } + } + return 1; + } + [[nodiscard]] size_t peek() { + size_t my_inconsistent_reads = 0; + std::array<size_t,SZ> tmp; + for (size_t i = 0; i < SZ; ++i) { + tmp[i] = state[i].load(std::memory_order_relaxed); + } + for (size_t n = 0; n < state_loop; ++n) { + for (size_t i = 0; i < SZ; ++i) { + if (state[i].load(std::memory_order_relaxed) != tmp[i]) [[unlikely]] { + ++my_inconsistent_reads; + } + } + } + return my_inconsistent_reads; + } + void commit_inconsistent_reads(size_t n) { + inconsistent_reads.fetch_add(n, std::memory_order_relaxed); + } + void commit_expected_writes(size_t n) { + expected_writes.fetch_add(n, std::memory_order_relaxed); + } + [[nodiscard]] bool check() const { + if (inconsistent_reads > 0) { + return false; + } + for (const auto& value: state) { + if (value != expected_writes) { + return false; + } + } + return true; + } + void report(const char *name) const { + if (check()) { + fprintf(stderr, "%s is thread safe\n", name); + } else { + fprintf(stderr, "%s is not thread safe\n", name); + fprintf(stderr, " inconsistent reads: %zu\n", inconsistent_reads.load()); + fprintf(stderr, " expected %zu, got [%zu,%zu,%zu,%zu,%zu]\n", + expected_writes.load(), state[0].load(), state[1].load(), state[2].load(), state[3].load(), state[4].load()); + } + } +}; + +// random generator used to make per-thread decisions +class Rnd { +private: + std::mt19937 _engine; + std::uniform_int_distribution<int> _dist; +public: + Rnd(uint32_t seed) : _engine(seed), _dist(0,9999) {} + bool operator()(int bp) { return _dist(_engine) < bp; } +}; + +//----------------------------------------------------------------------------- + +template<typename T> +concept basic_lockable = requires(T a) { + { a.lock() } -> std::same_as<void>; + { a.unlock() } -> std::same_as<void>; +}; + +template<typename T> +concept lockable = requires(T a) { + { a.try_lock() } -> std::same_as<bool>; + { a.lock() } -> std::same_as<void>; + { a.unlock() } -> std::same_as<void>; +}; + +template<typename T> +concept shared_lockable = requires(T a) { + { a.try_lock_shared() } -> std::same_as<bool>; + { a.lock_shared() } -> std::same_as<void>; + { a.unlock_shared() } -> std::same_as<void>; +}; + +template<typename T> +concept can_upgrade = requires(std::shared_lock<T> a, std::unique_lock<T> b) { + { try_upgrade(std::move(a)) } -> std::same_as<std::unique_lock<T>>; + { downgrade(std::move(b)) } -> std::same_as<std::shared_lock<T>>; +}; + +//----------------------------------------------------------------------------- + +template <size_t N> +auto run_loop(auto &f) { + static_assert(N % 4 == 0); + for (size_t i = 0; i < N / 4; ++i) { + f(); f(); f(); f(); + } +} + +double measure_ns(auto &work) __attribute__((noinline)); +double measure_ns(auto &work) { + constexpr double factor = LOOP_CNT; + auto t0 = steady_clock::now(); + run_loop<LOOP_CNT>(work); + return count_ns(steady_clock::now() - t0) / factor; +} + +struct BenchmarkResult { + double cost_ns; + double range_ns; + BenchmarkResult() + : cost_ns(std::numeric_limits<double>::max()), range_ns(0.0) {} + BenchmarkResult(double cost_ns_in, double range_ns_in) + : cost_ns(cost_ns_in), range_ns(range_ns_in) {} +}; + +struct Meets { + vespalib::test::ThreadMeets::Avg avg; + vespalib::test::ThreadMeets::Range<double> range; + Meets(size_t num_threads) : avg(num_threads), range(num_threads) {} +}; + +BenchmarkResult benchmark_ns(auto &&work, size_t num_threads = 1) { + Meets meets(num_threads); + auto entry = [&](Nexus &ctx) { + Timer timer; + BenchmarkResult result; + for (bool once_more = true; ctx.vote(once_more); once_more = (timer.elapsed() < budget)) { + auto my_ns = measure_ns(work); + auto cost_ns = meets.avg(my_ns); + auto range_ns = meets.range(my_ns); + if (cost_ns < result.cost_ns) { + result.cost_ns = cost_ns; + result.range_ns = range_ns; + } + } + return result; + }; + return Nexus(num_threads).run(entry); +} + +//----------------------------------------------------------------------------- + +template <typename T> +void estimate_cost() { + T lock; + auto name = getClassName(lock); + static_assert(basic_lockable<T>); + fprintf(stderr, "%s exclusive lock/unlock: %g ns\n", name.c_str(), + benchmark_ns([&lock]{ lock.lock(); lock.unlock(); }).cost_ns); + if constexpr (shared_lockable<T>) { + fprintf(stderr, "%s shared lock/unlock: %g ns\n", name.c_str(), + benchmark_ns([&lock]{ lock.lock_shared(); lock.unlock_shared(); }).cost_ns); + } + if constexpr (can_upgrade<T>) { + auto guard = std::shared_lock(lock); + fprintf(stderr, "%s upgrade/downgrade: %g ns\n", name.c_str(), + benchmark_ns([&lock]{ + assert(lock.try_convert_read_to_write()); + lock.convert_write_to_read(); + }).cost_ns); + } +} + +//----------------------------------------------------------------------------- + +template <typename T> +void thread_safety_loop(Nexus &ctx, T &lock, MyState &state, Meets &meets, int read_bp) { + Rnd rnd(ctx.thread_id()); + size_t write_cnt = 0; + size_t bad_reads = 0; + size_t loop_cnt = thread_safety_work / ctx.num_threads(); + ctx.barrier(); + auto t0 = steady_clock::now(); + for (size_t i = 0; i < loop_cnt; ++i) { + if (rnd(read_bp)) { + if constexpr (shared_lockable<T>) { + std::shared_lock guard(lock); + bad_reads += state.peek(); + } else { + std::lock_guard guard(lock); + bad_reads += state.peek(); + } + } else { + { + std::lock_guard guard(lock); + write_cnt += state.update(); + } + } + } + auto t1 = steady_clock::now(); + ctx.barrier(); + auto t2 = steady_clock::now(); + auto my_ms = count_ns(t1 - t0) / 1'000'000.0; + auto total_ms = count_ns(t2 - t0) / 1'000'000.0; + auto cost_ms = meets.avg(my_ms); + auto range_ms = meets.range(my_ms); + if (ctx.thread_id() == 0) { + fprintf(stderr, "---> %s with %2zu threads (%5d bp r): avg: %10.2f ms, range: %10.2f ms, max: %10.2f ms\n", + getClassName(lock).c_str(), ctx.num_threads(), read_bp, cost_ms, range_ms, total_ms); + } + state.commit_inconsistent_reads(bad_reads); + state.commit_expected_writes(write_cnt); +} + +//----------------------------------------------------------------------------- + +TEST(RWSpinLockTest, different_guards_work_with_rw_spin_lock) { + static_assert(basic_lockable<RWSpinLock>); + static_assert(lockable<RWSpinLock>); + static_assert(shared_lockable<RWSpinLock>); + static_assert(can_upgrade<RWSpinLock>); + RWSpinLock lock; + { auto guard = std::lock_guard(lock); } + { auto guard = std::unique_lock(lock); } + { auto guard = std::shared_lock(lock); } +} + +TEST(RWSpinLockTest, estimate_basic_costs) { + Rnd rnd(123); + MyState state; + fprintf(stderr, " rnd cost: %8.2f ns\n", benchmark_ns([&]{ rnd(50); }).cost_ns); + fprintf(stderr, " peek cost: %8.2f ns\n", benchmark_ns([&]{ (void) state.peek(); }).cost_ns); + fprintf(stderr, "update cost: %8.2f ns\n", benchmark_ns([&]{ (void) state.update(); }).cost_ns); +} + +template <typename T> +void benchmark_lock() { + auto lock = std::make_unique<T>(); + auto state = std::make_unique<MyState>(); + for (size_t bp: {10000, 9999, 5000, 0}) { + for (size_t num_threads: {8, 4, 2, 1}) { + if (bench || (bp == 9999 && num_threads == 8)) { + Meets meets(num_threads); + Nexus(num_threads).run([&](Nexus &ctx) { + thread_safety_loop(ctx, *lock, *state, meets, bp); + }); + } + } + } + state->report(getClassName(*lock).c_str()); + if (!std::same_as<T,DummyLock>) { + EXPECT_TRUE(state->check()); + } +} + +TEST(RWSpinLockTest, benchmark_dummy_lock) { benchmark_lock<DummyLock>(); } +TEST(RWSpinLockTest, benchmark_rw_spin_lock) { benchmark_lock<RWSpinLock>(); } +TEST(RWSpinLockTest, benchmark_shared_mutex) { benchmark_lock<std::shared_mutex>(); } +TEST(RWSpinLockTest, benchmark_mutex) { benchmark_lock<std::mutex>(); } +TEST(RWSpinLockTest, benchmark_spin_lock) { benchmark_lock<SpinLock>(); } + +TEST(RWSpinLockTest, estimate_single_threaded_costs) { + estimate_cost<DummyLock>(); + estimate_cost<SpinLock>(); + estimate_cost<std::mutex>(); + estimate_cost<RWSpinLock>(); + estimate_cost<std::shared_mutex>(); +} + +int main(int argc, char **argv) { + if (argc > 1 && (argv[1] == std::string("bench"))) { + bench = true; + budget = 5s; + state_loop = 1024; + fprintf(stderr, "running in benchmarking mode\n"); + ++argv; + --argc; + } + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/vespalib/src/tests/shared_string_repo/shared_string_repo_test.cpp b/vespalib/src/tests/shared_string_repo/shared_string_repo_test.cpp index dfcba14ba63..910c2d017ba 100644 --- a/vespalib/src/tests/shared_string_repo/shared_string_repo_test.cpp +++ b/vespalib/src/tests/shared_string_repo/shared_string_repo_test.cpp @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/vespalib/util/shared_string_repo.h> -#include <vespa/vespalib/util/rendezvous.h> +#include <vespa/vespalib/test/thread_meets.h> #include <vespa/vespalib/util/time.h> #include <vespa/vespalib/util/size_literals.h> #include <vespa/vespalib/util/stringfmt.h> @@ -115,41 +115,8 @@ std::unique_ptr<StringIdVector> make_weak_handles(const Handles &handles) { //----------------------------------------------------------------------------- -struct Avg : Rendezvous<double, double> { - explicit Avg(size_t n) : Rendezvous<double, double>(n) {} - void mingle() override { - double sum = 0; - for (size_t i = 0; i < size(); ++i) { - sum += in(i); - } - double result = sum / size(); - for (size_t i = 0; i < size(); ++i) { - out(i) = result; - } - } - double operator()(double value) { return rendezvous(value); } -}; - -struct Vote : Rendezvous<bool, bool> { - explicit Vote(size_t n) : Rendezvous<bool, bool>(n) {} - void mingle() override { - size_t true_cnt = 0; - size_t false_cnt = 0; - for (size_t i = 0; i < size(); ++i) { - if (in(i)) { - ++true_cnt; - } else { - ++false_cnt; - } - } - bool result = (true_cnt > false_cnt); - for (size_t i = 0; i < size(); ++i) { - out(i) = result; - } - } - [[nodiscard]] size_t num_threads() const { return size(); } - bool operator()(bool flag) { return rendezvous(flag); } -}; +using Avg = vespalib::test::ThreadMeets::Avg; +using Vote = vespalib::test::ThreadMeets::Vote; //----------------------------------------------------------------------------- @@ -174,7 +141,7 @@ struct Fixture { : avg(num_threads), vote(num_threads), work(make_strings(work_size)), direct_work(make_direct_strings(work_size)), start_time(steady_clock::now()) {} ~Fixture() { if (verbose) { - fprintf(stderr, "benchmark results for %zu threads:\n", vote.num_threads()); + fprintf(stderr, "benchmark results for %zu threads:\n", vote.size()); for (const auto &[tag, ms_cost]: time_ms) { fprintf(stderr, " %s: %g ms\n", tag.c_str(), ms_cost); } diff --git a/vespalib/src/tests/spin_lock/spin_lock_test.cpp b/vespalib/src/tests/spin_lock/spin_lock_test.cpp index 78e35a3e8d1..84044bfabcf 100644 --- a/vespalib/src/tests/spin_lock/spin_lock_test.cpp +++ b/vespalib/src/tests/spin_lock/spin_lock_test.cpp @@ -77,8 +77,8 @@ template <typename T> size_t thread_safety_loop(T &lock, MyState &state, size_t state.update(); } } - auto t1 = steady_clock::now(); TEST_BARRIER(); + auto t1 = steady_clock::now(); if (thread_id == 0) { auto t2 = steady_clock::now(); size_t total_ms = count_ms(t2 - t0); diff --git a/vespalib/src/vespa/vespalib/datastore/array_store.h b/vespalib/src/vespa/vespalib/datastore/array_store.h index 31b06b1f869..bff3a5cc5b2 100644 --- a/vespalib/src/vespa/vespalib/datastore/array_store.h +++ b/vespalib/src/vespa/vespalib/datastore/array_store.h @@ -26,7 +26,7 @@ namespace vespalib::datastore { * The default EntryRef type uses 19 bits for offset (524288 values) and 13 * bits for buffer id (8192 buffers). * - * Buffer type ids [1,maxSmallArrayTypeId] are used to allocate small + * Buffer type ids [1,max_type_id] are used to allocate small * arrays in datastore buffers. * * The simple type mapper (ArrayStoreSimpleTypeMapper) uses a 1-to-1 @@ -40,7 +40,7 @@ namespace vespalib::datastore { * Buffer type id 0 is used to heap allocate large arrays as * vespalib::Array instances. * - * The max value of maxSmallArrayTypeId is (2^(bufferBits - 3) - 1). + * The max value of max_type_id is (2^(bufferBits - 3) - 1). */ template <typename ElemT, typename RefT = EntryRefT<19>, typename TypeMapperT = ArrayStoreSimpleTypeMapper<ElemT> > class ArrayStore : public ICompactable @@ -74,7 +74,7 @@ public: using DynamicBufferTypeVector = typename check_dynamic_buffer_type_member<TypeMapper>::vector_type; private: uint32_t _largeArrayTypeId; - uint32_t _maxSmallArrayTypeId; + uint32_t _max_type_id; size_t _maxSmallArraySize; DataStoreType _store; TypeMapper _mapper; @@ -193,13 +193,13 @@ public: const TypeMapper& get_mapper() const noexcept { return _mapper; } - static ArrayStoreConfig optimizedConfigForHugePage(uint32_t maxSmallArrayTypeId, + static ArrayStoreConfig optimizedConfigForHugePage(uint32_t max_type_id, size_t hugePageSize, size_t smallPageSize, size_t min_num_entries_for_new_buffer, float allocGrowFactor); - static ArrayStoreConfig optimizedConfigForHugePage(uint32_t maxSmallArrayTypeId, + static ArrayStoreConfig optimizedConfigForHugePage(uint32_t max_type_id, const TypeMapper& mapper, size_t hugePageSize, size_t smallPageSize, diff --git a/vespalib/src/vespa/vespalib/datastore/array_store.hpp b/vespalib/src/vespa/vespalib/datastore/array_store.hpp index 8957e1f60aa..211176b8ad0 100644 --- a/vespalib/src/vespa/vespalib/datastore/array_store.hpp +++ b/vespalib/src/vespa/vespalib/datastore/array_store.hpp @@ -35,15 +35,15 @@ ArrayStore<ElemT, RefT, TypeMapperT>::initArrayTypes(const ArrayStoreConfig &cfg { _largeArrayTypeId = _store.addType(&_largeArrayType); assert(_largeArrayTypeId == 0); - _smallArrayTypes.reserve(_maxSmallArrayTypeId); + _smallArrayTypes.reserve(_max_type_id); if constexpr (has_dynamic_buffer_type) { - auto dynamic_buffer_types = _mapper.count_dynamic_buffer_types(_maxSmallArrayTypeId); - _smallArrayTypes.reserve(_maxSmallArrayTypeId - dynamic_buffer_types); + auto dynamic_buffer_types = _mapper.count_dynamic_buffer_types(_max_type_id); + _smallArrayTypes.reserve(_max_type_id - dynamic_buffer_types); _dynamicArrayTypes.reserve(dynamic_buffer_types); } else { - _smallArrayTypes.reserve(_maxSmallArrayTypeId); + _smallArrayTypes.reserve(_max_type_id); } - for (uint32_t type_id = 1; type_id <= _maxSmallArrayTypeId; ++type_id) { + for (uint32_t type_id = 1; type_id <= _max_type_id; ++type_id) { uint32_t act_type_id = _store.addType(initArrayType(cfg, memory_allocator, type_id)); assert(type_id == act_type_id); } @@ -59,8 +59,8 @@ template <typename ElemT, typename RefT, typename TypeMapperT> ArrayStore<ElemT, RefT, TypeMapperT>::ArrayStore(const ArrayStoreConfig &cfg, std::shared_ptr<alloc::MemoryAllocator> memory_allocator, TypeMapper&& mapper) : _largeArrayTypeId(0), - _maxSmallArrayTypeId(cfg.maxSmallArrayTypeId()), - _maxSmallArraySize(mapper.get_array_size(_maxSmallArrayTypeId)), + _max_type_id(cfg.max_type_id()), + _maxSmallArraySize(mapper.get_array_size(_max_type_id)), _store(), _mapper(std::move(mapper)), _smallArrayTypes(), @@ -249,14 +249,14 @@ ArrayStore<ElemT, RefT, TypeMapperT>::bufferState(EntryRef ref) template <typename ElemT, typename RefT, typename TypeMapperT> ArrayStoreConfig -ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t maxSmallArrayTypeId, - size_t hugePageSize, - size_t smallPageSize, - size_t min_num_entries_for_new_buffer, - float allocGrowFactor) +ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t max_type_id, + size_t hugePageSize, + size_t smallPageSize, + size_t min_num_entries_for_new_buffer, + float allocGrowFactor) { TypeMapper mapper; - return optimizedConfigForHugePage(maxSmallArrayTypeId, + return optimizedConfigForHugePage(max_type_id, mapper, hugePageSize, smallPageSize, @@ -266,14 +266,14 @@ ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t maxSma template <typename ElemT, typename RefT, typename TypeMapperT> ArrayStoreConfig -ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t maxSmallArrayTypeId, +ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t max_type_id, const TypeMapper& mapper, size_t hugePageSize, size_t smallPageSize, size_t min_num_entries_for_new_buffer, float allocGrowFactor) { - return ArrayStoreConfig::optimizeForHugePage(mapper.get_max_small_array_type_id(maxSmallArrayTypeId), + return ArrayStoreConfig::optimizeForHugePage(mapper.get_max_type_id(max_type_id), [&](uint32_t type_id) noexcept { return mapper.get_entry_size(type_id); }, hugePageSize, smallPageSize, diff --git a/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp b/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp index ee81938b49d..c7f0b69a85e 100644 --- a/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp +++ b/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp @@ -5,11 +5,11 @@ namespace vespalib::datastore { -ArrayStoreConfig::ArrayStoreConfig(uint32_t maxSmallArrayTypeId, const AllocSpec &defaultSpec) +ArrayStoreConfig::ArrayStoreConfig(uint32_t max_type_id, const AllocSpec &defaultSpec) : _allocSpecs(), _enable_free_lists(false) { - for (uint32_t type_id = 0; type_id < (maxSmallArrayTypeId + 1); ++type_id) { + for (uint32_t type_id = 0; type_id < (max_type_id + 1); ++type_id) { _allocSpecs.push_back(defaultSpec); } } @@ -45,7 +45,7 @@ alignToSmallPageSize(size_t value, size_t minLimit, size_t smallPageSize) } ArrayStoreConfig -ArrayStoreConfig::optimizeForHugePage(uint32_t maxSmallArrayTypeId, +ArrayStoreConfig::optimizeForHugePage(uint32_t max_type_id, std::function<size_t(uint32_t)> type_id_to_entry_size, size_t hugePageSize, size_t smallPageSize, @@ -55,7 +55,7 @@ ArrayStoreConfig::optimizeForHugePage(uint32_t maxSmallArrayTypeId, { AllocSpecVector allocSpecs; allocSpecs.emplace_back(0, maxEntryRefOffset, min_num_entries_for_new_buffer, allocGrowFactor); // large array spec; - for (uint32_t type_id = 1; type_id <= maxSmallArrayTypeId; ++type_id) { + for (uint32_t type_id = 1; type_id <= max_type_id; ++type_id) { size_t entry_size = type_id_to_entry_size(type_id); size_t num_entries_for_new_buffer = hugePageSize / entry_size; num_entries_for_new_buffer = capToLimits(num_entries_for_new_buffer, min_num_entries_for_new_buffer, maxEntryRefOffset); diff --git a/vespalib/src/vespa/vespalib/datastore/array_store_config.h b/vespalib/src/vespa/vespalib/datastore/array_store_config.h index d581e5958f0..3b62609d0f1 100644 --- a/vespalib/src/vespa/vespalib/datastore/array_store_config.h +++ b/vespalib/src/vespa/vespalib/datastore/array_store_config.h @@ -52,12 +52,12 @@ private: public: /** - * Setup an array store where buffer type ids [1-maxSmallArrayTypeId] are used to allocate small arrays in datastore buffers + * Setup an array store where buffer type ids [1-max_type_id] are used to allocate small arrays in datastore buffers * with the given default allocation spec. Larger arrays are heap allocated. */ - ArrayStoreConfig(uint32_t maxSmallArrayTypeId, const AllocSpec &defaultSpec); + ArrayStoreConfig(uint32_t max_type_id, const AllocSpec &defaultSpec); - uint32_t maxSmallArrayTypeId() const { return _allocSpecs.size() - 1; } + uint32_t max_type_id() const { return _allocSpecs.size() - 1; } const AllocSpec &spec_for_type_id(uint32_t type_id) const; ArrayStoreConfig& enable_free_lists(bool enable) & noexcept { _enable_free_lists = enable; @@ -72,7 +72,7 @@ public: /** * Generate a config that is optimized for the given memory huge page size. */ - static ArrayStoreConfig optimizeForHugePage(uint32_t maxSmallArrayTypeId, + static ArrayStoreConfig optimizeForHugePage(uint32_t max_type_id, std::function<size_t(uint32_t)> type_id_to_entry_size, size_t hugePageSize, size_t smallPageSize, diff --git a/vespalib/src/vespa/vespalib/datastore/array_store_simple_type_mapper.h b/vespalib/src/vespa/vespalib/datastore/array_store_simple_type_mapper.h index e43ee704071..314ef3c8aca 100644 --- a/vespalib/src/vespa/vespalib/datastore/array_store_simple_type_mapper.h +++ b/vespalib/src/vespa/vespalib/datastore/array_store_simple_type_mapper.h @@ -24,7 +24,7 @@ public: uint32_t get_type_id(size_t array_size) const noexcept { return array_size; } size_t get_array_size(uint32_t type_id) const noexcept { return type_id; } size_t get_entry_size(uint32_t type_id) const noexcept { return get_array_size(type_id) * sizeof(ElemT); } - static uint32_t get_max_small_array_type_id(uint32_t max_small_array_type_id) noexcept { return max_small_array_type_id; } + static uint32_t get_max_type_id(uint32_t max_type_id) noexcept { return max_type_id; } }; } diff --git a/vespalib/src/vespa/vespalib/datastore/array_store_type_mapper.cpp b/vespalib/src/vespa/vespalib/datastore/array_store_type_mapper.cpp index 520fb4cc4ef..c8a9b7ed2a4 100644 --- a/vespalib/src/vespa/vespalib/datastore/array_store_type_mapper.cpp +++ b/vespalib/src/vespa/vespalib/datastore/array_store_type_mapper.cpp @@ -33,10 +33,10 @@ ArrayStoreTypeMapper::get_array_size(uint32_t type_id) const } uint32_t -ArrayStoreTypeMapper::get_max_small_array_type_id(uint32_t max_small_array_type_id) const noexcept +ArrayStoreTypeMapper::get_max_type_id(uint32_t max_type_id) const noexcept { auto clamp_type_id = _array_sizes.size() - 1; - return (clamp_type_id < max_small_array_type_id) ? clamp_type_id : max_small_array_type_id; + return (clamp_type_id < max_type_id) ? clamp_type_id : max_type_id; } } diff --git a/vespalib/src/vespa/vespalib/datastore/array_store_type_mapper.h b/vespalib/src/vespa/vespalib/datastore/array_store_type_mapper.h index a73b6ef2e97..c7b57f73259 100644 --- a/vespalib/src/vespa/vespalib/datastore/array_store_type_mapper.h +++ b/vespalib/src/vespa/vespalib/datastore/array_store_type_mapper.h @@ -25,7 +25,7 @@ public: uint32_t get_type_id(size_t array_size) const; size_t get_array_size(uint32_t type_id) const; - uint32_t get_max_small_array_type_id(uint32_t max_small_array_type_id) const noexcept; + uint32_t get_max_type_id(uint32_t max_type_id) const noexcept; }; } diff --git a/vespalib/src/vespa/vespalib/test/CMakeLists.txt b/vespalib/src/vespa/vespalib/test/CMakeLists.txt index a60eb15a4d4..02ce1ba3416 100644 --- a/vespalib/src/vespa/vespalib/test/CMakeLists.txt +++ b/vespalib/src/vespa/vespalib/test/CMakeLists.txt @@ -3,6 +3,7 @@ vespa_add_library(vespalib_vespalib_test OBJECT SOURCES make_tls_options_for_testing.cpp memory_allocator_observer.cpp + nexus.cpp peer_policy_utils.cpp thread_meets.cpp time_tracer.cpp diff --git a/vespalib/src/vespa/vespalib/test/nexus.cpp b/vespalib/src/vespa/vespalib/test/nexus.cpp new file mode 100644 index 00000000000..b5d7b194576 --- /dev/null +++ b/vespalib/src/vespa/vespalib/test/nexus.cpp @@ -0,0 +1,15 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "nexus.h" + +namespace vespalib::test { + +size_t & +Nexus::my_thread_id() { + thread_local size_t thread_id = invalid_thread_id; + return thread_id; +} + +Nexus::~Nexus() = default; + +} diff --git a/vespalib/src/vespa/vespalib/test/nexus.h b/vespalib/src/vespa/vespalib/test/nexus.h new file mode 100644 index 00000000000..aeb9337b975 --- /dev/null +++ b/vespalib/src/vespa/vespalib/test/nexus.h @@ -0,0 +1,84 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "thread_meets.h" +#include <vespa/vespalib/util/thread.h> +#include <vespa/vespalib/util/require.h> +#include <optional> +#include <variant> + +namespace vespalib::test { + +class Nexus; +template <typename T> +concept nexus_thread_entry = requires(Nexus &ctx, T &&entry) { + entry(ctx); +}; + +/** + * Utility intended to make it easier to write multi-threaded code for + * testing and benchmarking. + **/ +class Nexus +{ +private: + vespalib::test::ThreadMeets::Vote _vote; + static size_t &my_thread_id(); +public: + constexpr static size_t invalid_thread_id = -1; + Nexus(size_t num_threads) noexcept : _vote(num_threads) {} + size_t num_threads() const noexcept { return _vote.size(); } + size_t thread_id() const noexcept { return my_thread_id(); } + bool vote(bool my_vote) { return _vote(my_vote); } + void barrier() { REQUIRE_EQ(_vote(true), true); } + struct select_thread_0 {}; + constexpr static auto merge_sum() { return [](auto a, auto b){ return a + b; }; } + auto run(auto &&entry, auto &&merge) requires nexus_thread_entry<decltype(entry)> { + ThreadPool pool; + using result_t = std::decay_t<decltype(entry(std::declval<Nexus&>()))>; + constexpr bool is_void = std::same_as<result_t, void>; + using stored_t = std::conditional<is_void, std::monostate, result_t>::type; + std::mutex lock; + std::optional<stored_t> result; + auto handle_result = [&](stored_t thread_result) noexcept { + if constexpr (std::same_as<std::decay_t<decltype(merge)>,select_thread_0>) { + if (thread_id() == 0) { + result = std::move(thread_result); + } + } else { + std::lock_guard guard(lock); + if (result.has_value()) { + result = merge(std::move(result).value(), + std::move(thread_result)); + } else { + result = std::move(thread_result); + } + } + }; + auto thread_main = [&](size_t thread_id) noexcept { + size_t old_thread_id = my_thread_id(); + my_thread_id() = thread_id; + if constexpr (is_void) { + entry(*this); + } else { + handle_result(entry(*this)); + } + my_thread_id() = old_thread_id; + }; + for (size_t i = 1; i < num_threads(); ++i) { + pool.start([i,&thread_main]() noexcept { thread_main(i); }); + } + thread_main(0); + pool.join(); + if constexpr (!is_void) { + return std::move(result).value(); + } + } + auto run(auto &&entry) requires nexus_thread_entry<decltype(entry)> { + return run(std::forward<decltype(entry)>(entry), select_thread_0{}); + } + ~Nexus(); +}; + +} diff --git a/vespalib/src/vespa/vespalib/test/thread_meets.cpp b/vespalib/src/vespa/vespalib/test/thread_meets.cpp index 9d23e0eab28..607179c53f9 100644 --- a/vespalib/src/vespa/vespalib/test/thread_meets.cpp +++ b/vespalib/src/vespa/vespalib/test/thread_meets.cpp @@ -9,4 +9,35 @@ ThreadMeets::Nop::mingle() { } +void +ThreadMeets::Avg::mingle() +{ + double sum = 0; + for (size_t i = 0; i < size(); ++i) { + sum += in(i); + } + double result = sum / size(); + for (size_t i = 0; i < size(); ++i) { + out(i) = result; + } +} + +void +ThreadMeets::Vote::mingle() +{ + size_t true_cnt = 0; + size_t false_cnt = 0; + for (size_t i = 0; i < size(); ++i) { + if (in(i)) { + ++true_cnt; + } else { + ++false_cnt; + } + } + bool result = (true_cnt > false_cnt); + for (size_t i = 0; i < size(); ++i) { + out(i) = result; + } +} + } diff --git a/vespalib/src/vespa/vespalib/test/thread_meets.h b/vespalib/src/vespa/vespalib/test/thread_meets.h index 62ca7779935..7ef4dcb9921 100644 --- a/vespalib/src/vespa/vespalib/test/thread_meets.h +++ b/vespalib/src/vespa/vespalib/test/thread_meets.h @@ -12,10 +12,67 @@ namespace vespalib::test { struct ThreadMeets { // can be used as a simple thread barrier struct Nop : vespalib::Rendezvous<bool,bool> { - Nop(size_t N) : vespalib::Rendezvous<bool,bool>(N) {} + explicit Nop(size_t N) : vespalib::Rendezvous<bool,bool>(N) {} void operator()() { rendezvous(false); } void mingle() override; }; + // calculate the average value across threads + struct Avg : Rendezvous<double, double> { + explicit Avg(size_t n) : Rendezvous<double, double>(n) {} + double operator()(double value) { return rendezvous(value); } + void mingle() override; + }; + // threads vote for true/false, majority wins (false on tie) + struct Vote : Rendezvous<bool, bool> { + explicit Vote(size_t n) : Rendezvous<bool, bool>(n) {} + bool operator()(bool flag) { return rendezvous(flag); } + void mingle() override; + }; + // sum of values across all threads + template <typename T> + struct Sum : vespalib::Rendezvous<T,T> { + using vespalib::Rendezvous<T,T>::in; + using vespalib::Rendezvous<T,T>::out; + using vespalib::Rendezvous<T,T>::size; + using vespalib::Rendezvous<T,T>::rendezvous; + explicit Sum(size_t N) : vespalib::Rendezvous<T,T>(N) {} + T operator()(T value) { return rendezvous(value); } + void mingle() override { + T acc{}; + for (size_t i = 0; i < size(); ++i) { + acc += in(i); + } + for (size_t i = 0; i < size(); ++i) { + out(i) = acc; + } + } + }; + // range of values across all threads + template <typename T> + struct Range : vespalib::Rendezvous<T,T> { + using vespalib::Rendezvous<T,T>::in; + using vespalib::Rendezvous<T,T>::out; + using vespalib::Rendezvous<T,T>::size; + using vespalib::Rendezvous<T,T>::rendezvous; + explicit Range(size_t N) : vespalib::Rendezvous<T,T>(N) {} + T operator()(T value) { return rendezvous(value); } + void mingle() override { + T min = in(0); + T max = in(0); + for (size_t i = 1; i < size(); ++i) { + if (in(i) < min) { + min = in(i); + } + if (in(i) > max) { + max = in(i); + } + } + T result = (max - min); + for (size_t i = 0; i < size(); ++i) { + out(i) = result; + } + } + }; // swap values between 2 threads template <typename T> struct Swap : vespalib::Rendezvous<T,T> { @@ -25,8 +82,8 @@ struct ThreadMeets { Swap() : vespalib::Rendezvous<T,T>(2) {} T operator()(T input) { return rendezvous(input); } void mingle() override { - out(1) = in(0); - out(0) = in(1); + out(1) = std::move(in(0)); + out(0) = std::move(in(1)); } }; }; diff --git a/vespalib/src/vespa/vespalib/util/rendezvous.h b/vespalib/src/vespa/vespalib/util/rendezvous.h index 2880f325d96..17a8729c54c 100644 --- a/vespalib/src/vespa/vespalib/util/rendezvous.h +++ b/vespalib/src/vespa/vespalib/util/rendezvous.h @@ -50,14 +50,6 @@ private: protected: /** - * Obtain the number of input and output values to be handled by - * mingle. This function is called by mingle. - * - * @return number of input and output values - **/ - size_t size() const { return _size; } - - /** * Obtain an input parameter. This function is called by mingle. * * @return reference to the appropriate input @@ -87,6 +79,11 @@ public: virtual ~Rendezvous(); /** + * @return number of participants + **/ + size_t size() const { return _size; } + + /** * Called by individual threads to synchronize execution and share * state with the mingle function. * diff --git a/vespalib/src/vespa/vespalib/util/rw_spin_lock.h b/vespalib/src/vespa/vespalib/util/rw_spin_lock.h new file mode 100644 index 00000000000..f2c15dcc0eb --- /dev/null +++ b/vespalib/src/vespa/vespalib/util/rw_spin_lock.h @@ -0,0 +1,189 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <mutex> +#include <shared_mutex> +#include <atomic> +#include <thread> +#include <cassert> +#include <utility> + +namespace vespalib { + +/** + * A reader-writer spin lock implementation. + * + * reader: shared access for any number of readers + * writer: exclusive access for a single writer + * + * valid lock combinations: + * {} + * {N readers} + * {1 writer} + * + * Trying to obtain a write lock will lead to not granting new read + * locks. + * + * This lock is intended for use-cases that involves mostly reading, + * with a little bit of writing. + * + * This class implements the Lockable and SharedLockable named + * requirements from the standard library, making it directly usable + * with std::shared_lock (reader) and std::unique_lock (writer) + * + * There is also some special glue added for lock upgrading and + * downgrading. + * + * NOTE: this implementation is experimental, mostly intended for + * benchmarking and trying to identify use-cases that work with + * rw locks. Upgrade locks that do not block readers might be + * implementet in the future. + **/ +class RWSpinLock { +private: + // [31: num readers][1: pending writer] + // a reader gets the lock by: + // increasing the number of readers while the pending writer bit is not set. + // a writer gets the lock by: + // changing the pending writer bit from 0 to 1 and then + // waiting for the number of readers to become 0 + // an upgrade is successful when: + // a reader is able to obtain the pending writer bit + std::atomic<uint32_t> _state; + + // Convenience function used to check if the pending writer bit is + // set in the given value. + bool has_pending_writer(uint32_t value) noexcept { + return (value & 1); + } + + // Wait for all readers to release their locks. + void wait_for_zero_readers(uint32_t &value) { + while (value != 1) { + std::this_thread::yield(); + value = _state.load(std::memory_order_acquire); + } + } + +public: + RWSpinLock() noexcept : _state(0) { + static_assert(std::atomic<uint32_t>::is_always_lock_free); + } + + // implementation of Lockable named requirement - vvv + + void lock() noexcept { + uint32_t expected = 0; + uint32_t desired = 1; + while (!_state.compare_exchange_weak(expected, desired, + std::memory_order_acquire, + std::memory_order_relaxed)) + { + while (has_pending_writer(expected)) { + std::this_thread::yield(); + expected = _state.load(std::memory_order_relaxed); + } + desired = expected + 1; + } + wait_for_zero_readers(desired); + } + + [[nodiscard]] bool try_lock() noexcept { + uint32_t expected = 0; + return _state.compare_exchange_strong(expected, 1, + std::memory_order_acquire, + std::memory_order_relaxed); + } + + void unlock() noexcept { + _state.store(0, std::memory_order_release); + } + + // implementation of Lockable named requirement - ^^^ + + // implementation of SharedLockable named requirement - vvv + + void lock_shared() noexcept { + uint32_t expected = 0; + uint32_t desired = 2; + while (!_state.compare_exchange_weak(expected, desired, + std::memory_order_acquire, + std::memory_order_relaxed)) + { + while (has_pending_writer(expected)) { + std::this_thread::yield(); + expected = _state.load(std::memory_order_relaxed); + } + desired = expected + 2; + } + } + + [[nodiscard]] bool try_lock_shared() noexcept { + uint32_t expected = 0; + uint32_t desired = 2; + while (!_state.compare_exchange_weak(expected, desired, + std::memory_order_acquire, + std::memory_order_relaxed)) + { + if (has_pending_writer(expected)) { + return false; + } + desired = expected + 2; + } + return true; + } + + void unlock_shared() noexcept { + _state.fetch_sub(2, std::memory_order_release); + } + + // implementation of SharedLockable named requirement - ^^^ + + // try to upgrade a read (shared) lock to a write (unique) lock + bool try_convert_read_to_write() noexcept { + uint32_t expected = 2; + uint32_t desired = 1; + while (!_state.compare_exchange_weak(expected, desired, + std::memory_order_acquire, + std::memory_order_relaxed)) + { + if (has_pending_writer(expected)) { + return false; + } + desired = expected - 1; + } + wait_for_zero_readers(desired); + return true; + } + + // convert a write (unique) lock to a read (shared) lock + void convert_write_to_read() noexcept { + _state.store(2, std::memory_order_release); + } +}; + +template<typename T> +concept rw_upgrade_downgrade_lock = requires(T a, T b) { + { a.try_convert_read_to_write() } -> std::same_as<bool>; + { b.convert_write_to_read() } -> std::same_as<void>; +}; + +template <rw_upgrade_downgrade_lock T> +[[nodiscard]] std::unique_lock<T> try_upgrade(std::shared_lock<T> &&guard) noexcept { + assert(guard.owns_lock()); + if (guard.mutex()->try_convert_read_to_write()) { + return {*guard.release(), std::adopt_lock}; + } else { + return {}; + } +} + +template <rw_upgrade_downgrade_lock T> +[[nodiscard]] std::shared_lock<T> downgrade(std::unique_lock<T> &&guard) noexcept { + assert(guard.owns_lock()); + guard.mutex()->convert_write_to_read(); + return {*guard.release(), std::adopt_lock}; +} + +} |