From 8780612685274a5c40f0f9537bbc6872bf8c7748 Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Thu, 23 Mar 2023 12:03:42 +0100 Subject: Always use HTTP/2 transport when using TLS --- client/go/internal/cli/cmd/feed.go | 15 +++---- client/go/internal/cli/cmd/root.go | 4 +- client/go/internal/mock/http.go | 2 - client/go/internal/util/http.go | 65 +++++++++++++++++++++++-------- client/go/internal/vespa/document/http.go | 10 ++--- client/go/internal/vespa/target.go | 10 ++--- client/go/internal/vespa/target_cloud.go | 10 ++--- client/go/internal/vespa/target_test.go | 6 --- 8 files changed, 69 insertions(+), 53 deletions(-) (limited to 'client/go/internal') diff --git a/client/go/internal/cli/cmd/feed.go b/client/go/internal/cli/cmd/feed.go index 621676d0353..f273c5aa826 100644 --- a/client/go/internal/cli/cmd/feed.go +++ b/client/go/internal/cli/cmd/feed.go @@ -12,15 +12,13 @@ import ( "github.com/vespa-engine/vespa/client/go/internal/vespa/document" ) -func addFeedFlags(cmd *cobra.Command, maxConnections *int, concurrency *int) { - cmd.PersistentFlags().IntVarP(maxConnections, "max-connections", "N", 8, "Maximum number of HTTP connections to use") +func addFeedFlags(cmd *cobra.Command, concurrency *int) { cmd.PersistentFlags().IntVarP(concurrency, "concurrency", "T", 64, "Number of goroutines to use for dispatching") } func newFeedCmd(cli *CLI) *cobra.Command { var ( - maxConnections int - concurrency int + concurrency int ) cmd := &cobra.Command{ Use: "feed FILE", @@ -45,21 +43,20 @@ newline (JSONL). return err } defer f.Close() - return feed(f, cli, maxConnections, concurrency) + return feed(f, cli, concurrency) }, } - addFeedFlags(cmd, &maxConnections, &concurrency) + addFeedFlags(cmd, &concurrency) return cmd } -func feed(r io.Reader, cli *CLI, maxConnections, concurrency int) error { +func feed(r io.Reader, cli *CLI, concurrency int) error { service, err := documentService(cli) if err != nil { return err } client := document.NewClient(document.ClientOptions{ - BaseURL: service.BaseURL, - MaxConnsPerHost: maxConnections, + BaseURL: service.BaseURL, }, service) dispatcher := document.NewDispatcher(client, concurrency) dec := document.NewDecoder(r) diff --git a/client/go/internal/cli/cmd/root.go b/client/go/internal/cli/cmd/root.go index 5edfd1136e5..58e940d59ef 100644 --- a/client/go/internal/cli/cmd/root.go +++ b/client/go/internal/cli/cmd/root.go @@ -366,7 +366,7 @@ func (c *CLI) createCloudTarget(targetType string, opts targetOptions) (vespa.Ta return nil, errHint(err, "Deployment to cloud requires a certificate. Try 'vespa auth cert'") } deploymentTLSOptions = vespa.TLSOptions{ - KeyPair: kp.KeyPair, + KeyPair: &kp.KeyPair, CertificateFile: kp.CertificateFile, PrivateKeyFile: kp.PrivateKeyFile, } @@ -377,7 +377,7 @@ func (c *CLI) createCloudTarget(targetType string, opts targetOptions) (vespa.Ta return nil, errHint(err, "Deployment to hosted requires an Athenz certificate", "Try renewing certificate with 'athenz-user-cert'") } apiTLSOptions = vespa.TLSOptions{ - KeyPair: kp.KeyPair, + KeyPair: &kp.KeyPair, CertificateFile: kp.CertificateFile, PrivateKeyFile: kp.PrivateKeyFile, } diff --git a/client/go/internal/mock/http.go b/client/go/internal/mock/http.go index d1fb4f28327..9c55f2e79bf 100644 --- a/client/go/internal/mock/http.go +++ b/client/go/internal/mock/http.go @@ -58,5 +58,3 @@ func (c *HTTPClient) Do(request *http.Request, timeout time.Duration) (*http.Res }, nil } - -func (c *HTTPClient) Transport() *http.Transport { return &http.Transport{} } diff --git a/client/go/internal/util/http.go b/client/go/internal/util/http.go index b18f9a00c6a..cb35932c8e7 100644 --- a/client/go/internal/util/http.go +++ b/client/go/internal/util/http.go @@ -2,22 +2,22 @@ package util import ( + "bytes" "crypto/tls" "fmt" "net/http" "time" "github.com/vespa-engine/vespa/client/go/internal/build" + "golang.org/x/net/http2" ) type HTTPClient interface { Do(request *http.Request, timeout time.Duration) (response *http.Response, error error) - Transport() *http.Transport } type defaultHTTPClient struct { - client *http.Client - transport *http.Transport + client *http.Client } func (c *defaultHTTPClient) Do(request *http.Request, timeout time.Duration) (response *http.Response, error error) { @@ -31,24 +31,55 @@ func (c *defaultHTTPClient) Do(request *http.Request, timeout time.Duration) (re return c.client.Do(request) } -func (c *defaultHTTPClient) Transport() *http.Transport { return c.transport } - func SetCertificate(client HTTPClient, certificates []tls.Certificate) { - client.Transport().TLSClientConfig = &tls.Config{ - Certificates: certificates, - MinVersion: tls.VersionTLS12, + c, ok := client.(*defaultHTTPClient) + if !ok { + return + } + // Use HTTP/2 transport explicitly. Connection reuse does not work properly when using regular http.Transport, even + // though it upgrades to HTTP/2 automatically + // https://github.com/golang/go/issues/16582 + // https://github.com/golang/go/issues/22091 + var transport *http2.Transport + if _, ok := c.client.Transport.(*http.Transport); ok { + transport = &http2.Transport{} + c.client.Transport = transport + } else if t, ok := c.client.Transport.(*http2.Transport); ok { + transport = t + } else { + panic(fmt.Sprintf("unknown transport type: %T", c.client.Transport)) + } + if ok && !c.hasCertificates(transport.TLSClientConfig, certificates) { + transport.TLSClientConfig = &tls.Config{ + Certificates: certificates, + MinVersion: tls.VersionTLS12, + } } } -func CreateClient(timeout time.Duration) HTTPClient { - transport := http.Transport{ - ForceAttemptHTTP2: true, +func (c *defaultHTTPClient) hasCertificates(tlsConfig *tls.Config, certs []tls.Certificate) bool { + if tlsConfig == nil { + return false } - return &defaultHTTPClient{ - client: &http.Client{ - Timeout: timeout, - Transport: &transport, - }, - transport: &transport, + if len(tlsConfig.Certificates) != len(certs) { + return false } + for i := 0; i < len(certs); i++ { + if len(tlsConfig.Certificates[i].Certificate) != len(certs[i].Certificate) { + return false + } + for j := 0; j < len(certs[i].Certificate); j++ { + if !bytes.Equal(tlsConfig.Certificates[i].Certificate[j], certs[i].Certificate[j]) { + return false + } + } + } + return true +} + +func CreateClient(timeout time.Duration) HTTPClient { + return &defaultHTTPClient{client: &http.Client{ + Timeout: timeout, + Transport: http.DefaultTransport, + }} } diff --git a/client/go/internal/vespa/document/http.go b/client/go/internal/vespa/document/http.go index ad6765aecc8..e86ceb1ebc5 100644 --- a/client/go/internal/vespa/document/http.go +++ b/client/go/internal/vespa/document/http.go @@ -26,11 +26,10 @@ type Client struct { // ClientOptions specifices the configuration options of a feed client. type ClientOptions struct { - MaxConnsPerHost int - BaseURL string - Timeout time.Duration - Route string - TraceLevel *int + BaseURL string + Timeout time.Duration + Route string + TraceLevel *int } type countingReader struct { @@ -51,7 +50,6 @@ func NewClient(options ClientOptions, httpClient util.HTTPClient) *Client { stats: NewStats(), now: time.Now, } - httpClient.Transport().MaxConnsPerHost = options.MaxConnsPerHost return c } diff --git a/client/go/internal/vespa/target.go b/client/go/internal/vespa/target.go index 0e173175720..51861eb12ab 100644 --- a/client/go/internal/vespa/target.go +++ b/client/go/internal/vespa/target.go @@ -74,7 +74,7 @@ type Target interface { // TLSOptions configures the client certificate to use for cloud API or service requests. type TLSOptions struct { - KeyPair tls.Certificate + KeyPair *tls.Certificate CertificateFile string PrivateKeyFile string AthenzDomain string @@ -92,8 +92,8 @@ type LogOptions struct { // Do sends request to this service. Any required authentication happens automatically. func (s *Service) Do(request *http.Request, timeout time.Duration) (*http.Response, error) { - if s.TLSOptions.AthenzDomain != "" { - accessToken, err := s.zts.AccessToken(s.TLSOptions.AthenzDomain, s.TLSOptions.KeyPair) + if s.TLSOptions.AthenzDomain != "" && s.TLSOptions.KeyPair != nil { + accessToken, err := s.zts.AccessToken(s.TLSOptions.AthenzDomain, *s.TLSOptions.KeyPair) if err != nil { return nil, err } @@ -105,8 +105,6 @@ func (s *Service) Do(request *http.Request, timeout time.Duration) (*http.Respon return s.httpClient.Do(request, timeout) } -func (s *Service) Transport() *http.Transport { return s.httpClient.Transport() } - // Wait polls the health check of this service until it succeeds or timeout passes. func (s *Service) Wait(timeout time.Duration) (int, error) { url := s.BaseURL @@ -118,7 +116,7 @@ func (s *Service) Wait(timeout time.Duration) (int, error) { default: return 0, fmt.Errorf("invalid service: %s", s.Name) } - return waitForOK(s.httpClient, url, &s.TLSOptions.KeyPair, timeout) + return waitForOK(s.httpClient, url, s.TLSOptions.KeyPair, timeout) } func (s *Service) Description() string { diff --git a/client/go/internal/vespa/target_cloud.go b/client/go/internal/vespa/target_cloud.go index 827d6c6a56a..2335d4f3432 100644 --- a/client/go/internal/vespa/target_cloud.go +++ b/client/go/internal/vespa/target_cloud.go @@ -160,8 +160,8 @@ func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, c return nil, fmt.Errorf("unknown service: %s", name) } - if service.TLSOptions.KeyPair.Certificate != nil { - util.SetCertificate(service, []tls.Certificate{service.TLSOptions.KeyPair}) + if service.TLSOptions.KeyPair != nil { + util.SetCertificate(service.httpClient, []tls.Certificate{*service.TLSOptions.KeyPair}) } return service, nil } @@ -275,7 +275,7 @@ func (t *cloudTarget) PrintLog(options LogOptions) error { if options.Follow { timeout = math.MaxInt64 // No timeout } - _, err = wait(t.httpClient, logFunc, requestFunc, &t.apiOptions.TLSOptions.KeyPair, timeout) + _, err = wait(t.httpClient, logFunc, requestFunc, t.apiOptions.TLSOptions.KeyPair, timeout) return err } @@ -326,7 +326,7 @@ func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error { } return true, nil } - _, err = wait(t.httpClient, jobSuccessFunc, requestFunc, &t.apiOptions.TLSOptions.KeyPair, timeout) + _, err = wait(t.httpClient, jobSuccessFunc, requestFunc, t.apiOptions.TLSOptions.KeyPair, timeout) return err } @@ -384,7 +384,7 @@ func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error { } return true, nil } - if _, err = wait(t.httpClient, endpointFunc, func() *http.Request { return req }, &t.apiOptions.TLSOptions.KeyPair, timeout); err != nil { + if _, err = wait(t.httpClient, endpointFunc, func() *http.Request { return req }, t.apiOptions.TLSOptions.KeyPair, timeout); err != nil { return err } if len(urlsByCluster) == 0 { diff --git a/client/go/internal/vespa/target_test.go b/client/go/internal/vespa/target_test.go index 4f2e361fb39..b9d65f3d8a4 100644 --- a/client/go/internal/vespa/target_test.go +++ b/client/go/internal/vespa/target_test.go @@ -154,11 +154,6 @@ func TestCheckVersion(t *testing.T) { } func createCloudTarget(t *testing.T, url string, logWriter io.Writer) Target { - kp, err := CreateKeyPair() - assert.Nil(t, err) - - x509KeyPair, err := tls.X509KeyPair(kp.Certificate, kp.PrivateKey) - assert.Nil(t, err) apiKey, err := CreateAPIKey() assert.Nil(t, err) @@ -172,7 +167,6 @@ func createCloudTarget(t *testing.T, url string, logWriter io.Writer) Target { Application: ApplicationID{Tenant: "t1", Application: "a1", Instance: "i1"}, Zone: ZoneID{Environment: "dev", Region: "us-north-1"}, }, - TLSOptions: TLSOptions{KeyPair: x509KeyPair}, }, LogOptions{Writer: logWriter}, ) -- cgit v1.2.3