diff options
9 files changed, 393 insertions, 15 deletions
diff --git a/client/go/internal/cli/cmd/query.go b/client/go/internal/cli/cmd/query.go index 3e5a60a15df..bddf3af06f9 100644 --- a/client/go/internal/cli/cmd/query.go +++ b/client/go/internal/cli/cmd/query.go @@ -5,9 +5,10 @@ package cmd import ( + "bufio" + "encoding/json" "fmt" "io" - "log" "net/http" "net/url" "strings" @@ -17,6 +18,7 @@ import ( "github.com/spf13/cobra" "github.com/vespa-engine/vespa/client/go/internal/curl" "github.com/vespa-engine/vespa/client/go/internal/ioutil" + "github.com/vespa-engine/vespa/client/go/internal/sse" "github.com/vespa-engine/vespa/client/go/internal/vespa" ) @@ -25,6 +27,7 @@ func newQueryCmd(cli *CLI) *cobra.Command { printCurl bool queryTimeoutSecs int waitSecs int + format string ) cmd := &cobra.Command{ Use: "query query-parameters", @@ -39,10 +42,11 @@ can be set by the syntax [parameter-name]=[value].`, SilenceUsage: true, Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return query(cli, args, queryTimeoutSecs, waitSecs, printCurl) + return query(cli, args, queryTimeoutSecs, waitSecs, printCurl, format) }, } cmd.PersistentFlags().BoolVarP(&printCurl, "verbose", "v", false, "Print the equivalent curl command for the query") + cmd.PersistentFlags().StringVarP(&format, "format", "", "human", "Output format. Must be 'human' (human-readable) or 'plain' (no formatting)") cmd.Flags().IntVarP(&queryTimeoutSecs, "timeout", "T", 10, "Timeout for the query in seconds") cli.bindWaitFlag(cmd, 0, &waitSecs) return cmd @@ -59,7 +63,7 @@ func printCurl(stderr io.Writer, url string, service *vespa.Service) error { return err } -func query(cli *CLI, arguments []string, timeoutSecs, waitSecs int, curl bool) error { +func query(cli *CLI, arguments []string, timeoutSecs, waitSecs int, curl bool, format string) error { target, err := cli.target(targetOptions{}) if err != nil { return err @@ -69,6 +73,11 @@ func query(cli *CLI, arguments []string, timeoutSecs, waitSecs int, curl bool) e if err != nil { return err } + switch format { + case "plain", "human": + default: + return fmt.Errorf("invalid format: %s", format) + } url, _ := url.Parse(service.BaseURL + "/search/") urlQuery := url.Query() for i := 0; i < len(arguments); i++ { @@ -98,7 +107,9 @@ func query(cli *CLI, arguments []string, timeoutSecs, waitSecs int, curl bool) e defer response.Body.Close() if response.StatusCode == 200 { - log.Print(ioutil.ReaderToJSON(response.Body)) + if err := printResponse(response.Body, response.Header.Get("Content-Type"), format, cli); err != nil { + return err + } } else if response.StatusCode/100 == 4 { return fmt.Errorf("invalid query: %s\n%s", response.Status, ioutil.ReaderToJSON(response.Body)) } else { @@ -107,6 +118,80 @@ func query(cli *CLI, arguments []string, timeoutSecs, waitSecs int, curl bool) e return nil } +func printResponse(body io.Reader, contentType, format string, cli *CLI) error { + contentType = strings.Split(contentType, ";")[0] + if contentType == "text/event-stream" { + return printResponseBody(body, printOptions{ + plainStream: format == "plain", + tokenStream: format == "human", + }, cli) + } + return printResponseBody(body, printOptions{parseJSON: format == "human"}, cli) +} + +type printOptions struct { + plainStream bool + tokenStream bool + parseJSON bool +} + +func printResponseBody(body io.Reader, options printOptions, cli *CLI) error { + if options.plainStream { + scanner := bufio.NewScanner(body) + for scanner.Scan() { + fmt.Fprintln(cli.Stdout, scanner.Text()) + } + return scanner.Err() + } else if options.tokenStream { + dec := sse.NewDecoder(body) + writingLine := false + for { + event, err := dec.Decode() + if err == io.EOF { + break + } else if err != nil { + return err + } + if event.Name == "token" { + if writingLine { + fmt.Fprint(cli.Stdout, " ") + } else { + writingLine = true + } + var token struct { + Value string `json:"token"` + } + value := event.Data // Optimistic parsing + if err := json.Unmarshal([]byte(event.Data), &token); err == nil { + value = token.Value + } + fmt.Fprint(cli.Stdout, value) + } else if !event.IsEnd() { + if writingLine { + fmt.Fprintln(cli.Stdout) + } + event.Data = ioutil.StringToJSON(event.Data) // Optimistically pretty-print JSON + fmt.Fprint(cli.Stdout, event.String()) + } else { + fmt.Fprintln(cli.Stdout) + break + } + } + return nil + } else if options.parseJSON { + text := ioutil.ReaderToJSON(body) // Optimistic, returns body as the raw string if it cannot be parsed to JSON + fmt.Fprintln(cli.Stdout, text) + return nil + } else { + b, err := io.ReadAll(body) + if err != nil { + return err + } + fmt.Fprintln(cli.Stdout, string(b)) + return nil + } +} + func splitArg(argument string) (string, string) { parts := strings.SplitN(argument, "=", 2) if len(parts) < 2 { diff --git a/client/go/internal/cli/cmd/query_test.go b/client/go/internal/cli/cmd/query_test.go index 4a35f1530ec..3a2eeba159a 100644 --- a/client/go/internal/cli/cmd/query_test.go +++ b/client/go/internal/cli/cmd/query_test.go @@ -5,6 +5,7 @@ package cmd import ( + "net/http" "strconv" "testing" @@ -31,6 +32,16 @@ func TestQueryVerbose(t *testing.T) { assert.Equal(t, "{\n \"query\": \"result\"\n}\n", stdout.String()) } +func TestQueryUnformatted(t *testing.T) { + client := &mock.HTTPClient{} + client.NextResponseString(200, "{\"query\":\"result\"}") + cli, stdout, _ := newTestCLI(t) + cli.httpClient = client + + assert.Nil(t, cli.Run("-t", "http://127.0.0.1:8080", "--format=plain", "query", "select from sources * where title contains 'foo'")) + assert.Equal(t, "{\"query\":\"result\"}\n", stdout.String()) +} + func TestQueryNonJsonResult(t *testing.T) { assertQuery(t, "?timeout=10s&yql=select+from+sources+%2A+where+title+contains+%27foo%27", @@ -69,6 +80,71 @@ func TestServerError(t *testing.T) { assertQueryServiceError(t, 501, "server error message") } +func TestStreamingQuery(t *testing.T) { + body := ` +event: token +data: {"token": "The"} + +event: token +data: {"token": "Manhattan"} + +event: token +data: {"token": "Project"} + +event: end +` + assertStreamingQuery(t, "The Manhattan Project\n", body) + assertStreamingQuery(t, body, body, "--format=plain") + + bodyWithError := ` +event: token +data: {"token": "The"} + +event: token +data: Manhattan + +event: error +data: {"message": "something went wrong"} +` + assertStreamingQuery(t, `The Manhattan +event: error +data: { + "message": "something went wrong" +} +`, bodyWithError) + assertStreamingQuery(t, bodyWithError, bodyWithError, "--format=plain") +} + +func assertStreamingQuery(t *testing.T, expectedOutput, body string, args ...string) { + t.Helper() + client := &mock.HTTPClient{} + response := mock.HTTPResponse{Status: 200, Header: make(http.Header)} + response.Header.Set("Content-Type", "text/event-stream") + response.Body = []byte(body) + client.NextResponse(response) + cli, stdout, stderr := newTestCLI(t) + cli.httpClient = client + + assert.Nil(t, cli.Run(append(args, "-t", "http://127.0.0.1:8080", "query", "select something")...)) + assert.Equal(t, "", stderr.String()) + assert.Equal(t, expectedOutput, stdout.String()) +} + +func assertStreamingQueryErr(t *testing.T, expectedOut, expectedErr, body string, args ...string) { + t.Helper() + client := &mock.HTTPClient{} + response := mock.HTTPResponse{Status: 200, Header: make(http.Header)} + response.Header.Set("Content-Type", "text/event-stream") + response.Body = []byte(body) + client.NextResponse(response) + cli, stdout, stderr := newTestCLI(t) + cli.httpClient = client + + assert.NotNil(t, cli.Run(append(args, "-t", "http://127.0.0.1:8080", "query", "select something")...)) + assert.Equal(t, expectedErr, stderr.String()) + assert.Equal(t, expectedOut, stdout.String()) +} + func assertQuery(t *testing.T, expectedQuery string, query ...string) { client := &mock.HTTPClient{} client.NextResponseString(200, "{\"query\":\"result\"}") diff --git a/client/go/internal/ioutil/ioutil.go b/client/go/internal/ioutil/ioutil.go index d3a33698d13..0abb9915c39 100644 --- a/client/go/internal/ioutil/ioutil.go +++ b/client/go/internal/ioutil/ioutil.go @@ -66,6 +66,9 @@ func ReaderToJSON(reader io.Reader) string { return prettyJSON.String() } +// StringToJSON returns string s as indented JSON. +func StringToJSON(s string) string { return ReaderToJSON(strings.NewReader(s)) } + // AtomicWriteFile atomically writes data to filename. func AtomicWriteFile(filename string, data []byte) error { dir := filepath.Dir(filename) diff --git a/client/go/internal/sse/sse.go b/client/go/internal/sse/sse.go new file mode 100644 index 00000000000..9a120944eec --- /dev/null +++ b/client/go/internal/sse/sse.go @@ -0,0 +1,112 @@ +package sse + +import ( + "bufio" + "fmt" + "io" + "strings" +) + +// Event represents a server-sent event. Name and ID are optional fields. +type Event struct { + Name string + ID string + Data string +} + +// Decoder reads and decodes a server-sent event from an input stream. +type Decoder struct { + scanner *bufio.Scanner +} + +// Decode reads and decodes the next event from the underlying reader. +func (d *Decoder) Decode() (*Event, error) { + // https://www.rfc-editor.org/rfc/rfc8895.html#name-server-push-server-sent-eve + var ( + event Event + data strings.Builder + lastRead string + gotName bool + gotID bool + gotData bool + decoding bool + ) + for d.scanner.Scan() { + line := strings.TrimSpace(d.scanner.Text()) + if line == "" { + if decoding { + break // Done with event + } else { + continue // Waiting for first non-empty line + } + } + lastRead = line + decoding = true + parts := strings.SplitN(line, ": ", 2) + if len(parts) < 2 || parts[0] == "" { + continue + } + switch parts[0] { + case "event": + if gotName { + return nil, fmt.Errorf("got more than one event line: last read %q", lastRead) + } + event.Name = parts[1] + gotName = true + case "id": + if gotID { + return nil, fmt.Errorf("got more than one id line: last read %q", lastRead) + } + event.ID = parts[1] + gotID = true + case "data": + if data.Len() > 0 { + data.WriteString(" ") + } + data.WriteString(parts[1]) + gotData = true + default: + return nil, fmt.Errorf("invalid field name %q: last read %q", parts[0], lastRead) + } + } + if err := d.scanner.Err(); err != nil { + return nil, err + } + if !decoding { + return nil, io.EOF + } + if !event.IsEnd() && !gotData { + return nil, fmt.Errorf("no data line found for event: last read %q", lastRead) + } + event.Data = data.String() + return &event, nil +} + +// NewDecoder creates a new Decoder that reads from r. +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{scanner: bufio.NewScanner(r)} +} + +// IsEnd returns whether this event indicates that the stream has ended. +func (e Event) IsEnd() bool { return e.Name == "end" } + +// String returns the string representation of event e. +func (e Event) String() string { + var sb strings.Builder + if e.Name != "" { + sb.WriteString("event: ") + sb.WriteString(e.Name) + sb.WriteString("\n") + } + if e.ID != "" { + sb.WriteString("id: ") + sb.WriteString(e.ID) + sb.WriteString("\n") + } + if e.Data != "" { + sb.WriteString("data: ") + sb.WriteString(e.Data) + sb.WriteString("\n") + } + return sb.String() +} diff --git a/client/go/internal/sse/sse_test.go b/client/go/internal/sse/sse_test.go new file mode 100644 index 00000000000..0e0d6929c75 --- /dev/null +++ b/client/go/internal/sse/sse_test.go @@ -0,0 +1,107 @@ +package sse + +import ( + "errors" + "io" + "strings" + "testing" +) + +func TestDecoder(t *testing.T) { + r := strings.NewReader(` +event: foo +id: 42 +data: data 1 + +event: bar +ignored +: ignored +data: data 2 + + +event: baz +data: data 3 +data: data 4 + +event: bax +data: data 5 + +data: data 6 + +event: end +`) + dec := NewDecoder(r) + + assertDecode(&Event{Name: "foo", ID: "42", Data: "data 1"}, dec, t) + assertDecode(&Event{Name: "bar", Data: "data 2"}, dec, t) + assertDecode(&Event{Name: "baz", Data: "data 3 data 4"}, dec, t) + assertDecode(&Event{Name: "bax", Data: "data 5"}, dec, t) + assertDecode(&Event{Data: "data 6"}, dec, t) + assertDecode(&Event{Name: "end"}, dec, t) + assertDecodeErr(io.EOF, dec, t) +} + +func TestDecoderInvalid(t *testing.T) { + r := strings.NewReader(` +event: foo +event: bar + +event: foo +id: 42 + +foo + +bad: field +`) + dec := NewDecoder(r) + assertDecodeErrString(`got more than one event line: last read "event: bar"`, dec, t) + assertDecodeErrString(`no data line found for event: last read "id: 42"`, dec, t) + assertDecodeErrString(`no data line found for event: last read "foo"`, dec, t) + assertDecodeErrString(`invalid field name "bad": last read "bad: field"`, dec, t) +} + +func TestString(t *testing.T) { + assertString(t, "event: foo\ndata: bar\n", Event{Name: "foo", Data: "bar"}) + assertString(t, "event: foo\nid: 42\ndata: bar\n", Event{Name: "foo", ID: "42", Data: "bar"}) +} + +func assertString(t *testing.T, want string, event Event) { + t.Helper() + got := event.String() + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func assertDecode(want *Event, dec *Decoder, t *testing.T) { + t.Helper() + got, err := dec.Decode() + if err != nil { + t.Fatalf("got error %v, want %+v", err, want) + } + if got.Name != want.Name { + t.Errorf("got Name=%q, want %q", got.Name, want.Name) + } + if got.ID != want.ID { + t.Errorf("got ID=%q, want %q", got.ID, want.ID) + } + if got.Data != want.Data { + t.Errorf("got Data=%q, want %q", got.Data, want.Data) + } +} + +func assertDecodeErrString(errMsg string, dec *Decoder, t *testing.T) { + t.Helper() + assertDecodeErr(errors.New(errMsg), dec, t) +} + +func assertDecodeErr(wantErr error, dec *Decoder, t *testing.T) { + t.Helper() + _, err := dec.Decode() + if err == nil { + t.Fatal("expected error") + } + if err.Error() != wantErr.Error() { + t.Errorf("got error %q, want %q", err, wantErr) + } +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/LocalProvider.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/LocalProvider.java index 75b47fbdf60..da7dfdc7b84 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/LocalProvider.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/LocalProvider.java @@ -17,6 +17,7 @@ import com.yahoo.vespa.model.search.SearchCluster; import java.util.ArrayList; import java.util.LinkedHashSet; import java.util.List; +import java.util.Objects; import java.util.Set; /** @@ -37,7 +38,7 @@ public class LocalProvider extends Provider implements @Override public void getConfig(ClusterConfig.Builder builder) { - assert (searchCluster != null) : "Null search cluster!"; + Objects.requireNonNull(searchCluster, "Null search cluster!"); builder.clusterId(searchCluster.getClusterIndex()); builder.clusterName(searchCluster.getClusterName()); diff --git a/container-search/src/test/java/com/yahoo/prelude/IndexFactsFactory.java b/container-search/src/test/java/com/yahoo/prelude/IndexFactsFactory.java index 97846404852..c2c86c5abb6 100644 --- a/container-search/src/test/java/com/yahoo/prelude/IndexFactsFactory.java +++ b/container-search/src/test/java/com/yahoo/prelude/IndexFactsFactory.java @@ -11,15 +11,9 @@ import com.yahoo.container.QrSearchersConfig; */ public abstract class IndexFactsFactory { - public static IndexFacts newInstance(String configId) { - return new IndexFacts(new IndexModel(resolveConfig(IndexInfoConfig.class, configId), - resolveConfig(QrSearchersConfig.class, configId))); - - } - - public static IndexFacts newInstance(String indexInfoConfigId, String qrSearchersConfigId) { + public static IndexFacts newInstance(String indexInfoConfigId) { return new IndexFacts(new IndexModel(resolveConfig(IndexInfoConfig.class, indexInfoConfigId), - resolveConfig(QrSearchersConfig.class, qrSearchersConfigId))); + resolveConfig(QrSearchersConfig.class, null))); } diff --git a/container-search/src/test/java/com/yahoo/prelude/querytransform/test/CJKSearcherTestCase.java b/container-search/src/test/java/com/yahoo/prelude/querytransform/test/CJKSearcherTestCase.java index 22ba8754572..6c45dd2a1b1 100644 --- a/container-search/src/test/java/com/yahoo/prelude/querytransform/test/CJKSearcherTestCase.java +++ b/container-search/src/test/java/com/yahoo/prelude/querytransform/test/CJKSearcherTestCase.java @@ -31,7 +31,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse; public class CJKSearcherTestCase { private final IndexFacts indexFacts = IndexFactsFactory.newInstance("file:src/test/java/com/yahoo/prelude/" + - "querytransform/test/cjk-index-info.cfg", null); + "querytransform/test/cjk-index-info.cfg"); @Test void testTermWeight() { diff --git a/container-search/src/test/java/com/yahoo/prelude/querytransform/test/StemmingSearcherTestCase.java b/container-search/src/test/java/com/yahoo/prelude/querytransform/test/StemmingSearcherTestCase.java index a925668dffd..3db51ad4b8a 100644 --- a/container-search/src/test/java/com/yahoo/prelude/querytransform/test/StemmingSearcherTestCase.java +++ b/container-search/src/test/java/com/yahoo/prelude/querytransform/test/StemmingSearcherTestCase.java @@ -29,7 +29,7 @@ public class StemmingSearcherTestCase { private static final Linguistics linguistics = new SimpleLinguistics(); private final IndexFacts indexFacts = IndexFactsFactory.newInstance("dir:src/test/java/com/yahoo/prelude/" + - "querytransform/test/", null); + "querytransform/test/"); @Test void testStemOnlySomeTerms() { |