diff options
author | Martin Polden <mpolden@mpolden.no> | 2024-06-20 09:58:20 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-20 09:58:20 +0200 |
commit | 7954a76d91d47fae17ec5c03705aef5bc87745b9 (patch) | |
tree | bf20e67eb6761304971f90e976d997fd315b99e9 /client/go | |
parent | e18d33c07a975490b33108f9d9427a4d89728b29 (diff) | |
parent | 653bea036959da9dd5a05282dbf603aa0129790c (diff) |
Merge pull request #31609 from vespa-engine/mpolden/feed-headers
Add header option to feed command
Diffstat (limited to 'client/go')
-rw-r--r-- | client/go/internal/cli/cmd/feed.go | 7 | ||||
-rw-r--r-- | client/go/internal/cli/cmd/query.go | 17 | ||||
-rw-r--r-- | client/go/internal/httputil/httputil.go | 17 | ||||
-rw-r--r-- | client/go/internal/vespa/document/http.go | 10 |
4 files changed, 33 insertions, 18 deletions
diff --git a/client/go/internal/cli/cmd/feed.go b/client/go/internal/cli/cmd/feed.go index e696333a3f1..6c5df8b3e84 100644 --- a/client/go/internal/cli/cmd/feed.go +++ b/client/go/internal/cli/cmd/feed.go @@ -20,6 +20,7 @@ func addFeedFlags(cli *CLI, cmd *cobra.Command, options *feedOptions) { cmd.PersistentFlags().IntVar(&options.connections, "connections", 8, "The number of connections to use") cmd.PersistentFlags().StringVar(&options.compression, "compression", "auto", `Compression mode to use. Default is "auto" which compresses large documents. Must be "auto", "gzip" or "none"`) cmd.PersistentFlags().IntVar(&options.timeoutSecs, "timeout", 0, "Individual feed operation timeout in seconds. 0 to disable (default 0)") + cmd.Flags().StringSliceVarP(&options.headers, "header", "", nil, "Add a header to all HTTP requests, on the format 'Header: Value'. This can be specified multiple times") cmd.PersistentFlags().IntVar(&options.doomSecs, "deadline", 0, "Exit if this number of seconds elapse without any successful operations. 0 to disable (default 0)") cmd.PersistentFlags().BoolVar(&options.verbose, "verbose", false, "Verbose mode. Print successful operations in addition to errors") cmd.PersistentFlags().StringVar(&options.route, "route", "", `Target Vespa route for feed operations (default "default")`) @@ -49,6 +50,7 @@ type feedOptions struct { speedtestBytes int speedtestSecs int waitSecs int + headers []string memprofile string cpuprofile string @@ -238,12 +240,17 @@ func feed(files []string, options feedOptions, cli *CLI, cmd *cobra.Command) err if err != nil { return err } + header, err := httputil.ParseHeader(options.headers) + if err != nil { + return err + } client, err := document.NewClient(document.ClientOptions{ Compression: compression, Timeout: timeout, Route: options.route, TraceLevel: options.traceLevel, BaseURL: baseURL, + Header: header, Speedtest: options.speedtestBytes > 0, NowFunc: cli.now, }, clients) diff --git a/client/go/internal/cli/cmd/query.go b/client/go/internal/cli/cmd/query.go index 6feead66082..54bbc5fb59a 100644 --- a/client/go/internal/cli/cmd/query.go +++ b/client/go/internal/cli/cmd/query.go @@ -17,6 +17,7 @@ import ( "github.com/fatih/color" "github.com/spf13/cobra" "github.com/vespa-engine/vespa/client/go/internal/curl" + "github.com/vespa-engine/vespa/client/go/internal/httputil" "github.com/vespa-engine/vespa/client/go/internal/ioutil" "github.com/vespa-engine/vespa/client/go/internal/sse" "github.com/vespa-engine/vespa/client/go/internal/vespa" @@ -68,20 +69,6 @@ func printCurl(stderr io.Writer, url string, service *vespa.Service) error { return err } -func parseHeaders(headers []string) (http.Header, error) { - h := make(http.Header) - for _, header := range headers { - kv := strings.SplitN(header, ":", 2) - if len(kv) < 2 { - return nil, fmt.Errorf("invalid header %q: missing colon separator", header) - } - k := kv[0] - v := strings.TrimSpace(kv[1]) - h.Add(k, v) - } - return h, nil -} - func query(cli *CLI, arguments []string, timeoutSecs int, curl bool, format string, headers []string, waiter *Waiter) error { target, err := cli.target(targetOptions{}) if err != nil { @@ -118,7 +105,7 @@ func query(cli *CLI, arguments []string, timeoutSecs int, curl bool, format stri return err } } - header, err := parseHeaders(headers) + header, err := httputil.ParseHeader(headers) if err != nil { return err } diff --git a/client/go/internal/httputil/httputil.go b/client/go/internal/httputil/httputil.go index e1e27de5523..56ac31d93a8 100644 --- a/client/go/internal/httputil/httputil.go +++ b/client/go/internal/httputil/httputil.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/http" + "strings" "time" "github.com/vespa-engine/vespa/client/go/internal/build" @@ -99,3 +100,19 @@ func NewClient(timeout time.Duration) Client { }, } } + +// ParseHeader parses headers slice into a http.Header. Each element in the slice is expected to contain a string on +// the format "Header: Value". +func ParseHeader(headers []string) (http.Header, error) { + h := make(http.Header) + for _, header := range headers { + kv := strings.SplitN(header, ":", 2) + if len(kv) < 2 { + return nil, fmt.Errorf("invalid header %q: missing colon separator", header) + } + k := kv[0] + v := strings.TrimSpace(kv[1]) + h.Add(k, v) + } + return h, nil +} diff --git a/client/go/internal/vespa/document/http.go b/client/go/internal/vespa/document/http.go index 3871ab19edd..80789a208b6 100644 --- a/client/go/internal/vespa/document/http.go +++ b/client/go/internal/vespa/document/http.go @@ -43,6 +43,7 @@ type Client struct { // ClientOptions specifices the configuration options of a feed client. type ClientOptions struct { BaseURL string + Header http.Header Timeout time.Duration Route string TraceLevel int @@ -216,11 +217,14 @@ func (c *Client) prepare(document Document) (*http.Request, *bytes.Buffer, error return pd.request, pd.buf, pd.err } -func newRequest(method, url string, body io.Reader, gzipped bool) (*http.Request, error) { +func (c *Client) newRequest(method, url string, body io.Reader, gzipped bool) (*http.Request, error) { req, err := http.NewRequest(method, url, body) if err != nil { return nil, err } + for k, v := range c.options.Header { + req.Header[k] = v + } req.Header.Set("Content-Type", "application/json; charset=utf-8") if gzipped { req.Header.Set("Content-Encoding", "gzip") @@ -231,7 +235,7 @@ func newRequest(method, url string, body io.Reader, gzipped bool) (*http.Request func (c *Client) createRequest(method, url string, body []byte, buf *bytes.Buffer) (*http.Request, error) { buf.Reset() if len(body) == 0 { - return newRequest(method, url, nil, false) + return c.newRequest(method, url, nil, false) } useGzip := c.options.Compression == CompressionGzip || (c.options.Compression == CompressionAuto && len(body) > 512) var r io.Reader @@ -249,7 +253,7 @@ func (c *Client) createRequest(method, url string, body []byte, buf *bytes.Buffe } else { r = bytes.NewReader(body) } - return newRequest(method, url, r, useGzip) + return c.newRequest(method, url, r, useGzip) } func (c *Client) clientTimeout() time.Duration { |