aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2024-03-15 13:00:37 +0100
committerGitHub <noreply@github.com>2024-03-15 13:00:37 +0100
commitc44558af0a82be9dda1f1ba8bf84337789a906e5 (patch)
tree8720452a861db16698ef39b746baa7affa8b79ce
parent79398debd1275041b60a2354dea1dce53d12c4a6 (diff)
parent87740ab0322a9e534a30440d2f4bda05d3f570d1 (diff)
Merge pull request #30622 from vespa-engine/mpolden/streaming-query
CLI: Support streaming query response
-rw-r--r--client/go/internal/cli/cmd/query.go93
-rw-r--r--client/go/internal/cli/cmd/query_test.go76
-rw-r--r--client/go/internal/ioutil/ioutil.go3
-rw-r--r--client/go/internal/sse/sse.go112
-rw-r--r--client/go/internal/sse/sse_test.go107
5 files changed, 387 insertions, 4 deletions
diff --git a/client/go/internal/cli/cmd/query.go b/client/go/internal/cli/cmd/query.go
index 3e5a60a15df..bddf3af06f9 100644
--- a/client/go/internal/cli/cmd/query.go
+++ b/client/go/internal/cli/cmd/query.go
@@ -5,9 +5,10 @@
package cmd
import (
+ "bufio"
+ "encoding/json"
"fmt"
"io"
- "log"
"net/http"
"net/url"
"strings"
@@ -17,6 +18,7 @@ import (
"github.com/spf13/cobra"
"github.com/vespa-engine/vespa/client/go/internal/curl"
"github.com/vespa-engine/vespa/client/go/internal/ioutil"
+ "github.com/vespa-engine/vespa/client/go/internal/sse"
"github.com/vespa-engine/vespa/client/go/internal/vespa"
)
@@ -25,6 +27,7 @@ func newQueryCmd(cli *CLI) *cobra.Command {
printCurl bool
queryTimeoutSecs int
waitSecs int
+ format string
)
cmd := &cobra.Command{
Use: "query query-parameters",
@@ -39,10 +42,11 @@ can be set by the syntax [parameter-name]=[value].`,
SilenceUsage: true,
Args: cobra.MinimumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
- return query(cli, args, queryTimeoutSecs, waitSecs, printCurl)
+ return query(cli, args, queryTimeoutSecs, waitSecs, printCurl, format)
},
}
cmd.PersistentFlags().BoolVarP(&printCurl, "verbose", "v", false, "Print the equivalent curl command for the query")
+ cmd.PersistentFlags().StringVarP(&format, "format", "", "human", "Output format. Must be 'human' (human-readable) or 'plain' (no formatting)")
cmd.Flags().IntVarP(&queryTimeoutSecs, "timeout", "T", 10, "Timeout for the query in seconds")
cli.bindWaitFlag(cmd, 0, &waitSecs)
return cmd
@@ -59,7 +63,7 @@ func printCurl(stderr io.Writer, url string, service *vespa.Service) error {
return err
}
-func query(cli *CLI, arguments []string, timeoutSecs, waitSecs int, curl bool) error {
+func query(cli *CLI, arguments []string, timeoutSecs, waitSecs int, curl bool, format string) error {
target, err := cli.target(targetOptions{})
if err != nil {
return err
@@ -69,6 +73,11 @@ func query(cli *CLI, arguments []string, timeoutSecs, waitSecs int, curl bool) e
if err != nil {
return err
}
+ switch format {
+ case "plain", "human":
+ default:
+ return fmt.Errorf("invalid format: %s", format)
+ }
url, _ := url.Parse(service.BaseURL + "/search/")
urlQuery := url.Query()
for i := 0; i < len(arguments); i++ {
@@ -98,7 +107,9 @@ func query(cli *CLI, arguments []string, timeoutSecs, waitSecs int, curl bool) e
defer response.Body.Close()
if response.StatusCode == 200 {
- log.Print(ioutil.ReaderToJSON(response.Body))
+ if err := printResponse(response.Body, response.Header.Get("Content-Type"), format, cli); err != nil {
+ return err
+ }
} else if response.StatusCode/100 == 4 {
return fmt.Errorf("invalid query: %s\n%s", response.Status, ioutil.ReaderToJSON(response.Body))
} else {
@@ -107,6 +118,80 @@ func query(cli *CLI, arguments []string, timeoutSecs, waitSecs int, curl bool) e
return nil
}
+func printResponse(body io.Reader, contentType, format string, cli *CLI) error {
+ contentType = strings.Split(contentType, ";")[0]
+ if contentType == "text/event-stream" {
+ return printResponseBody(body, printOptions{
+ plainStream: format == "plain",
+ tokenStream: format == "human",
+ }, cli)
+ }
+ return printResponseBody(body, printOptions{parseJSON: format == "human"}, cli)
+}
+
+type printOptions struct {
+ plainStream bool
+ tokenStream bool
+ parseJSON bool
+}
+
+func printResponseBody(body io.Reader, options printOptions, cli *CLI) error {
+ if options.plainStream {
+ scanner := bufio.NewScanner(body)
+ for scanner.Scan() {
+ fmt.Fprintln(cli.Stdout, scanner.Text())
+ }
+ return scanner.Err()
+ } else if options.tokenStream {
+ dec := sse.NewDecoder(body)
+ writingLine := false
+ for {
+ event, err := dec.Decode()
+ if err == io.EOF {
+ break
+ } else if err != nil {
+ return err
+ }
+ if event.Name == "token" {
+ if writingLine {
+ fmt.Fprint(cli.Stdout, " ")
+ } else {
+ writingLine = true
+ }
+ var token struct {
+ Value string `json:"token"`
+ }
+ value := event.Data // Optimistic parsing
+ if err := json.Unmarshal([]byte(event.Data), &token); err == nil {
+ value = token.Value
+ }
+ fmt.Fprint(cli.Stdout, value)
+ } else if !event.IsEnd() {
+ if writingLine {
+ fmt.Fprintln(cli.Stdout)
+ }
+ event.Data = ioutil.StringToJSON(event.Data) // Optimistically pretty-print JSON
+ fmt.Fprint(cli.Stdout, event.String())
+ } else {
+ fmt.Fprintln(cli.Stdout)
+ break
+ }
+ }
+ return nil
+ } else if options.parseJSON {
+ text := ioutil.ReaderToJSON(body) // Optimistic, returns body as the raw string if it cannot be parsed to JSON
+ fmt.Fprintln(cli.Stdout, text)
+ return nil
+ } else {
+ b, err := io.ReadAll(body)
+ if err != nil {
+ return err
+ }
+ fmt.Fprintln(cli.Stdout, string(b))
+ return nil
+ }
+}
+
func splitArg(argument string) (string, string) {
parts := strings.SplitN(argument, "=", 2)
if len(parts) < 2 {
diff --git a/client/go/internal/cli/cmd/query_test.go b/client/go/internal/cli/cmd/query_test.go
index 4a35f1530ec..3a2eeba159a 100644
--- a/client/go/internal/cli/cmd/query_test.go
+++ b/client/go/internal/cli/cmd/query_test.go
@@ -5,6 +5,7 @@
package cmd
import (
+ "net/http"
"strconv"
"testing"
@@ -31,6 +32,16 @@ func TestQueryVerbose(t *testing.T) {
assert.Equal(t, "{\n \"query\": \"result\"\n}\n", stdout.String())
}
+func TestQueryUnformatted(t *testing.T) {
+ client := &mock.HTTPClient{}
+ client.NextResponseString(200, "{\"query\":\"result\"}")
+ cli, stdout, _ := newTestCLI(t)
+ cli.httpClient = client
+
+ assert.Nil(t, cli.Run("-t", "http://127.0.0.1:8080", "--format=plain", "query", "select from sources * where title contains 'foo'"))
+ assert.Equal(t, "{\"query\":\"result\"}\n", stdout.String())
+}
+
func TestQueryNonJsonResult(t *testing.T) {
assertQuery(t,
"?timeout=10s&yql=select+from+sources+%2A+where+title+contains+%27foo%27",
@@ -69,6 +80,71 @@ func TestServerError(t *testing.T) {
assertQueryServiceError(t, 501, "server error message")
}
+func TestStreamingQuery(t *testing.T) {
+ body := `
+event: token
+data: {"token": "The"}
+
+event: token
+data: {"token": "Manhattan"}
+
+event: token
+data: {"token": "Project"}
+
+event: end
+`
+ assertStreamingQuery(t, "The Manhattan Project\n", body)
+ assertStreamingQuery(t, body, body, "--format=plain")
+
+ bodyWithError := `
+event: token
+data: {"token": "The"}
+
+event: token
+data: Manhattan
+
+event: error
+data: {"message": "something went wrong"}
+`
+ assertStreamingQuery(t, `The Manhattan
+event: error
+data: {
+ "message": "something went wrong"
+}
+`, bodyWithError)
+ assertStreamingQuery(t, bodyWithError, bodyWithError, "--format=plain")
+}
+
+func assertStreamingQuery(t *testing.T, expectedOutput, body string, args ...string) {
+ t.Helper()
+ client := &mock.HTTPClient{}
+ response := mock.HTTPResponse{Status: 200, Header: make(http.Header)}
+ response.Header.Set("Content-Type", "text/event-stream")
+ response.Body = []byte(body)
+ client.NextResponse(response)
+ cli, stdout, stderr := newTestCLI(t)
+ cli.httpClient = client
+
+ assert.Nil(t, cli.Run(append(args, "-t", "http://127.0.0.1:8080", "query", "select something")...))
+ assert.Equal(t, "", stderr.String())
+ assert.Equal(t, expectedOutput, stdout.String())
+}
+
+func assertStreamingQueryErr(t *testing.T, expectedOut, expectedErr, body string, args ...string) {
+ t.Helper()
+ client := &mock.HTTPClient{}
+ response := mock.HTTPResponse{Status: 200, Header: make(http.Header)}
+ response.Header.Set("Content-Type", "text/event-stream")
+ response.Body = []byte(body)
+ client.NextResponse(response)
+ cli, stdout, stderr := newTestCLI(t)
+ cli.httpClient = client
+
+ assert.NotNil(t, cli.Run(append(args, "-t", "http://127.0.0.1:8080", "query", "select something")...))
+ assert.Equal(t, expectedErr, stderr.String())
+ assert.Equal(t, expectedOut, stdout.String())
+}
+
func assertQuery(t *testing.T, expectedQuery string, query ...string) {
client := &mock.HTTPClient{}
client.NextResponseString(200, "{\"query\":\"result\"}")
diff --git a/client/go/internal/ioutil/ioutil.go b/client/go/internal/ioutil/ioutil.go
index d3a33698d13..0abb9915c39 100644
--- a/client/go/internal/ioutil/ioutil.go
+++ b/client/go/internal/ioutil/ioutil.go
@@ -66,6 +66,9 @@ func ReaderToJSON(reader io.Reader) string {
return prettyJSON.String()
}
+// StringToJSON returns string s as indented JSON.
+func StringToJSON(s string) string { return ReaderToJSON(strings.NewReader(s)) }
+
// AtomicWriteFile atomically writes data to filename.
func AtomicWriteFile(filename string, data []byte) error {
dir := filepath.Dir(filename)
diff --git a/client/go/internal/sse/sse.go b/client/go/internal/sse/sse.go
new file mode 100644
index 00000000000..9a120944eec
--- /dev/null
+++ b/client/go/internal/sse/sse.go
@@ -0,0 +1,112 @@
+package sse
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "strings"
+)
+
+// Event represents a server-sent event. Name and ID are optional fields.
+type Event struct {
+ Name string
+ ID string
+ Data string
+}
+
+// Decoder reads and decodes a server-sent event from an input stream.
+type Decoder struct {
+ scanner *bufio.Scanner
+}
+
+// Decode reads and decodes the next event from the underlying reader.
+func (d *Decoder) Decode() (*Event, error) {
+ // https://www.rfc-editor.org/rfc/rfc8895.html#name-server-push-server-sent-eve
+ var (
+ event Event
+ data strings.Builder
+ lastRead string
+ gotName bool
+ gotID bool
+ gotData bool
+ decoding bool
+ )
+ for d.scanner.Scan() {
+ line := strings.TrimSpace(d.scanner.Text())
+ if line == "" {
+ if decoding {
+ break // Done with event
+ } else {
+ continue // Waiting for first non-empty line
+ }
+ }
+ lastRead = line
+ decoding = true
+ parts := strings.SplitN(line, ": ", 2)
+ if len(parts) < 2 || parts[0] == "" {
+ continue
+ }
+ switch parts[0] {
+ case "event":
+ if gotName {
+ return nil, fmt.Errorf("got more than one event line: last read %q", lastRead)
+ }
+ event.Name = parts[1]
+ gotName = true
+ case "id":
+ if gotID {
+ return nil, fmt.Errorf("got more than one id line: last read %q", lastRead)
+ }
+ event.ID = parts[1]
+ gotID = true
+ case "data":
+ if data.Len() > 0 {
+ data.WriteString(" ")
+ }
+ data.WriteString(parts[1])
+ gotData = true
+ default:
+ return nil, fmt.Errorf("invalid field name %q: last read %q", parts[0], lastRead)
+ }
+ }
+ if err := d.scanner.Err(); err != nil {
+ return nil, err
+ }
+ if !decoding {
+ return nil, io.EOF
+ }
+ if !event.IsEnd() && !gotData {
+ return nil, fmt.Errorf("no data line found for event: last read %q", lastRead)
+ }
+ event.Data = data.String()
+ return &event, nil
+}
+
+// NewDecoder creates a new Decoder that reads from r.
+func NewDecoder(r io.Reader) *Decoder {
+ return &Decoder{scanner: bufio.NewScanner(r)}
+}
+
+// IsEnd returns whether this event indicates that the stream has ended.
+func (e Event) IsEnd() bool { return e.Name == "end" }
+
+// String returns the string representation of event e.
+func (e Event) String() string {
+ var sb strings.Builder
+ if e.Name != "" {
+ sb.WriteString("event: ")
+ sb.WriteString(e.Name)
+ sb.WriteString("\n")
+ }
+ if e.ID != "" {
+ sb.WriteString("id: ")
+ sb.WriteString(e.ID)
+ sb.WriteString("\n")
+ }
+ if e.Data != "" {
+ sb.WriteString("data: ")
+ sb.WriteString(e.Data)
+ sb.WriteString("\n")
+ }
+ return sb.String()
+}
diff --git a/client/go/internal/sse/sse_test.go b/client/go/internal/sse/sse_test.go
new file mode 100644
index 00000000000..0e0d6929c75
--- /dev/null
+++ b/client/go/internal/sse/sse_test.go
@@ -0,0 +1,107 @@
+package sse
+
+import (
+ "errors"
+ "io"
+ "strings"
+ "testing"
+)
+
+func TestDecoder(t *testing.T) {
+ r := strings.NewReader(`
+event: foo
+id: 42
+data: data 1
+
+event: bar
+ignored
+: ignored
+data: data 2
+
+
+event: baz
+data: data 3
+data: data 4
+
+event: bax
+data: data 5
+
+data: data 6
+
+event: end
+`)
+ dec := NewDecoder(r)
+
+ assertDecode(&Event{Name: "foo", ID: "42", Data: "data 1"}, dec, t)
+ assertDecode(&Event{Name: "bar", Data: "data 2"}, dec, t)
+ assertDecode(&Event{Name: "baz", Data: "data 3 data 4"}, dec, t)
+ assertDecode(&Event{Name: "bax", Data: "data 5"}, dec, t)
+ assertDecode(&Event{Data: "data 6"}, dec, t)
+ assertDecode(&Event{Name: "end"}, dec, t)
+ assertDecodeErr(io.EOF, dec, t)
+}
+
+func TestDecoderInvalid(t *testing.T) {
+ r := strings.NewReader(`
+event: foo
+event: bar
+
+event: foo
+id: 42
+
+foo
+
+bad: field
+`)
+ dec := NewDecoder(r)
+ assertDecodeErrString(`got more than one event line: last read "event: bar"`, dec, t)
+ assertDecodeErrString(`no data line found for event: last read "id: 42"`, dec, t)
+ assertDecodeErrString(`no data line found for event: last read "foo"`, dec, t)
+ assertDecodeErrString(`invalid field name "bad": last read "bad: field"`, dec, t)
+}
+
+func TestString(t *testing.T) {
+ assertString(t, "event: foo\ndata: bar\n", Event{Name: "foo", Data: "bar"})
+ assertString(t, "event: foo\nid: 42\ndata: bar\n", Event{Name: "foo", ID: "42", Data: "bar"})
+}
+
+func assertString(t *testing.T, want string, event Event) {
+ t.Helper()
+ got := event.String()
+ if got != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
+}
+
+func assertDecode(want *Event, dec *Decoder, t *testing.T) {
+ t.Helper()
+ got, err := dec.Decode()
+ if err != nil {
+ t.Fatalf("got error %v, want %+v", err, want)
+ }
+ if got.Name != want.Name {
+ t.Errorf("got Name=%q, want %q", got.Name, want.Name)
+ }
+ if got.ID != want.ID {
+ t.Errorf("got ID=%q, want %q", got.ID, want.ID)
+ }
+ if got.Data != want.Data {
+ t.Errorf("got Data=%q, want %q", got.Data, want.Data)
+ }
+}
+
+func assertDecodeErrString(errMsg string, dec *Decoder, t *testing.T) {
+ t.Helper()
+ assertDecodeErr(errors.New(errMsg), dec, t)
+}
+
+func assertDecodeErr(wantErr error, dec *Decoder, t *testing.T) {
+ t.Helper()
+ _, err := dec.Decode()
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if err.Error() != wantErr.Error() {
+ t.Errorf("got error %q, want %q", err, wantErr)
+ }
+}