diff options
author | Martin Polden <mpolden@mpolden.no> | 2024-03-14 11:12:48 +0100 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2024-03-14 11:15:44 +0100 |
commit | 9be7a2f59e1416889ef44e98a57571aabfac95c9 (patch) | |
tree | 64b97b8ee72c75dad83edccc5a786a2fb3585b31 /client/go | |
parent | 0b29d67b18754c16eb95ccf1d0a5ec6573f1c90d (diff) |
Support streaming query response
Diffstat (limited to 'client/go')
-rw-r--r-- | client/go/internal/cli/cmd/query.go | 92 | ||||
-rw-r--r-- | client/go/internal/cli/cmd/query_test.go | 71 |
2 files changed, 159 insertions, 4 deletions
diff --git a/client/go/internal/cli/cmd/query.go b/client/go/internal/cli/cmd/query.go index 3e5a60a15df..c55c28fb0f6 100644 --- a/client/go/internal/cli/cmd/query.go +++ b/client/go/internal/cli/cmd/query.go @@ -5,9 +5,10 @@ package cmd import ( + "bufio" + "encoding/json" "fmt" "io" - "log" "net/http" "net/url" "strings" @@ -17,6 +18,7 @@ import ( "github.com/spf13/cobra" "github.com/vespa-engine/vespa/client/go/internal/curl" "github.com/vespa-engine/vespa/client/go/internal/ioutil" + "github.com/vespa-engine/vespa/client/go/internal/sse" "github.com/vespa-engine/vespa/client/go/internal/vespa" ) @@ -25,6 +27,7 @@ func newQueryCmd(cli *CLI) *cobra.Command { printCurl bool queryTimeoutSecs int waitSecs int + format string ) cmd := &cobra.Command{ Use: "query query-parameters", @@ -39,10 +42,11 @@ can be set by the syntax [parameter-name]=[value].`, SilenceUsage: true, Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return query(cli, args, queryTimeoutSecs, waitSecs, printCurl) + return query(cli, args, queryTimeoutSecs, waitSecs, printCurl, format) }, } cmd.PersistentFlags().BoolVarP(&printCurl, "verbose", "v", false, "Print the equivalent curl command for the query") + cmd.PersistentFlags().StringVarP(&format, "format", "", "human", "Output format. Must be 'human' (human-readable) or 'plain' (no formatting)") cmd.Flags().IntVarP(&queryTimeoutSecs, "timeout", "T", 10, "Timeout for the query in seconds") cli.bindWaitFlag(cmd, 0, &waitSecs) return cmd @@ -59,7 +63,7 @@ func printCurl(stderr io.Writer, url string, service *vespa.Service) error { return err } -func query(cli *CLI, arguments []string, timeoutSecs, waitSecs int, curl bool) error { +func query(cli *CLI, arguments []string, timeoutSecs, waitSecs int, curl bool, format string) error { target, err := cli.target(targetOptions{}) if err != nil { return err @@ -69,6 +73,11 @@ func query(cli *CLI, arguments []string, timeoutSecs, waitSecs int, curl bool) e if err != nil { return err } + switch format { + case "plain", "human": + default: + return fmt.Errorf("invalid format: %s", format) + } url, _ := url.Parse(service.BaseURL + "/search/") urlQuery := url.Query() for i := 0; i < len(arguments); i++ { @@ -98,7 +107,9 @@ func query(cli *CLI, arguments []string, timeoutSecs, waitSecs int, curl bool) e defer response.Body.Close() if response.StatusCode == 200 { - log.Print(ioutil.ReaderToJSON(response.Body)) + if err := printResponse(response.Body, response.Header.Get("Content-Type"), format, cli); err != nil { + return err + } } else if response.StatusCode/100 == 4 { return fmt.Errorf("invalid query: %s\n%s", response.Status, ioutil.ReaderToJSON(response.Body)) } else { @@ -107,6 +118,79 @@ func query(cli *CLI, arguments []string, timeoutSecs, waitSecs int, curl bool) e return nil } +func printResponse(body io.Reader, contentType, format string, cli *CLI) error { + contentType = strings.Split(contentType, ";")[0] + if contentType == "text/event-stream" { + return printResponseBody(body, printOptions{ + plainStream: format == "plain", + tokenStream: format == "human", + }, cli) + } + return printResponseBody(body, printOptions{parseJSON: format == "human"}, cli) +} + +type printOptions struct { + plainStream bool + tokenStream bool + parseJSON bool +} + +func printResponseBody(body io.Reader, options printOptions, cli *CLI) error { + if options.plainStream { + scanner := bufio.NewScanner(body) + for scanner.Scan() { + fmt.Fprintln(cli.Stdout, scanner.Text()) + } + return scanner.Err() + } else if options.tokenStream { + dec := sse.NewDecoder(body) + writingLine := false + for { + event, err := dec.Decode() + if err == io.EOF { + break + } else if err != nil { + return err + } + if event.Name == "token" { + if writingLine { + fmt.Fprint(cli.Stdout, " ") + } else { + writingLine = true + } + var token struct { + Value string `json:"token"` + } + value := event.Data // Optimistic parsing + if err := json.Unmarshal([]byte(event.Data), &token); err == nil { + value = token.Value + } + fmt.Fprint(cli.Stdout, value) + } else if !event.IsEnd() { + if writingLine { + fmt.Fprintln(cli.Stdout) + } + return errHint(fmt.Errorf("unknown event type: %q", event.Name), "Event parsing can be disabled with --format=plain") + } else { + fmt.Fprintln(cli.Stdout) + break + } + } + return nil + } else if options.parseJSON { + text := ioutil.ReaderToJSON(body) // Optimistic, returns body as the raw string if it cannot be parsed to JSON + fmt.Fprintln(cli.Stdout, text) + return nil + } else { + b, err := io.ReadAll(body) + if err != nil { + return err + } + fmt.Fprintln(cli.Stdout, string(b)) + return nil + } +} + func splitArg(argument string) (string, string) { parts := strings.SplitN(argument, "=", 2) if len(parts) < 2 { diff --git a/client/go/internal/cli/cmd/query_test.go b/client/go/internal/cli/cmd/query_test.go index 4a35f1530ec..470a5c5d1e2 100644 --- a/client/go/internal/cli/cmd/query_test.go +++ b/client/go/internal/cli/cmd/query_test.go @@ -5,6 +5,7 @@ package cmd import ( + "net/http" "strconv" "testing" @@ -31,6 +32,16 @@ func TestQueryVerbose(t *testing.T) { assert.Equal(t, "{\n \"query\": \"result\"\n}\n", stdout.String()) } +func TestQueryUnformatted(t *testing.T) { + client := &mock.HTTPClient{} + client.NextResponseString(200, "{\"query\":\"result\"}") + cli, stdout, _ := newTestCLI(t) + cli.httpClient = client + + assert.Nil(t, cli.Run("-t", "http://127.0.0.1:8080", "--format=plain", "query", "select from sources * where title contains 'foo'")) + assert.Equal(t, "{\"query\":\"result\"}\n", stdout.String()) +} + func TestQueryNonJsonResult(t *testing.T) { assertQuery(t, "?timeout=10s&yql=select+from+sources+%2A+where+title+contains+%27foo%27", @@ -69,6 +80,66 @@ func TestServerError(t *testing.T) { assertQueryServiceError(t, 501, "server error message") } +func TestStreamingQuery(t *testing.T) { + body := ` +event: token +data: {"token": "The"} + +event: token +data: {"token": "Manhattan"} + +event: token +data: {"token": "Project"} + +event: end +` + assertStreamingQuery(t, "The Manhattan Project\n", body) + assertStreamingQuery(t, body, body, "--format=plain") + + bodyWithError := ` +event: token +data: {"token": "The"} + +event: token +data: Manhattan + +event: error +data: {"message": "something went wrong"} +` + assertStreamingQueryErr(t, "The Manhattan\n", "Error: unknown event type: \"error\"\nHint: Event parsing can be disabled with --format=plain\n", bodyWithError) + assertStreamingQuery(t, bodyWithError, bodyWithError, "--format=plain") +} + +func assertStreamingQuery(t *testing.T, expectedOutput, body string, args ...string) { + t.Helper() + client := &mock.HTTPClient{} + response := mock.HTTPResponse{Status: 200, Header: make(http.Header)} + response.Header.Set("Content-Type", "text/event-stream") + response.Body = []byte(body) + client.NextResponse(response) + cli, stdout, stderr := newTestCLI(t) + cli.httpClient = client + + assert.Nil(t, cli.Run(append(args, "-t", "http://127.0.0.1:8080", "query", "select something")...)) + assert.Equal(t, "", stderr.String()) + assert.Equal(t, expectedOutput, stdout.String()) +} + +func assertStreamingQueryErr(t *testing.T, expectedOut, expectedErr, body string, args ...string) { + t.Helper() + client := &mock.HTTPClient{} + response := mock.HTTPResponse{Status: 200, Header: make(http.Header)} + response.Header.Set("Content-Type", "text/event-stream") + response.Body = []byte(body) + client.NextResponse(response) + cli, stdout, stderr := newTestCLI(t) + cli.httpClient = client + + assert.NotNil(t, cli.Run(append(args, "-t", "http://127.0.0.1:8080", "query", "select something")...)) + assert.Equal(t, expectedErr, stderr.String()) + assert.Equal(t, expectedOut, stdout.String()) +} + func assertQuery(t *testing.T, expectedQuery string, query ...string) { client := &mock.HTTPClient{} client.NextResponseString(200, "{\"query\":\"result\"}") |