From 311b8b8c7c220840f4277709a8f2c74943a6e7eb Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Tue, 9 May 2023 19:53:01 +0200 Subject: Stream request body --- client/go/internal/cli/cmd/feed_test.go | 1 + client/go/internal/vespa/document/http.go | 60 +++++++++++++++++--------- client/go/internal/vespa/document/http_test.go | 60 +++++++++++--------------- 3 files changed, 66 insertions(+), 55 deletions(-) diff --git a/client/go/internal/cli/cmd/feed_test.go b/client/go/internal/cli/cmd/feed_test.go index 097d4ae5fa3..bd0b9544e37 100644 --- a/client/go/internal/cli/cmd/feed_test.go +++ b/client/go/internal/cli/cmd/feed_test.go @@ -27,6 +27,7 @@ func TestFeed(t *testing.T) { clock := &manualClock{tick: time.Second} cli, stdout, stderr := newTestCLI(t) httpClient := cli.httpClient.(*mock.HTTPClient) + httpClient.ReadBody = true cli.now = clock.now td := t.TempDir() diff --git a/client/go/internal/vespa/document/http.go b/client/go/internal/vespa/document/http.go index cf13d72c26b..74cc46ec962 100644 --- a/client/go/internal/vespa/document/http.go +++ b/client/go/internal/vespa/document/http.go @@ -1,6 +1,7 @@ package document import ( + "bufio" "bytes" "encoding/json" "fmt" @@ -53,6 +54,19 @@ type ClientOptions struct { NowFunc func() time.Time } +type countingReader struct { + reader io.ReadCloser + bytesRead int64 +} + +func (r *countingReader) Read(p []byte) (int, error) { + n, err := r.reader.Read(p) + r.bytesRead += int64(n) + return n, err +} + +func (r *countingReader) Close() error { return r.reader.Close() } + type countingHTTPClient struct { client util.HTTPClient inflight int64 @@ -200,34 +214,36 @@ func (c *Client) buffer() *bytes.Buffer { return buf } -func (c *Client) createRequest(method, url string, body []byte) (*http.Request, error) { +func (c *Client) createRequest(method, url string, body []byte) (*http.Request, *countingReader, error) { if len(body) == 0 { - return http.NewRequest(method, url, nil) + req, err := http.NewRequest(method, url, nil) + return req, nil, err } - var buf *bytes.Buffer useGzip := c.options.Compression == CompressionGzip || (c.options.Compression == CompressionAuto && len(body) > 512) - if useGzip { - buf = bytes.NewBuffer(make([]byte, 0, 1024)) - w := c.gzipWriter(buf) - writeRequestBody(w, body) - if err := w.Close(); err != nil { - return nil, err + r, w := io.Pipe() + go func() { + defer w.Close() + if useGzip { + buf := bufio.NewWriterSize(w, 1024) + zw := c.gzipWriter(buf) + writeRequestBody(zw, body) + zw.Close() + c.gzippers.Put(zw) + buf.Flush() + } else { + writeRequestBody(w, body) } - c.gzippers.Put(w) - } else { - buf = bytes.NewBuffer(make([]byte, 0, len(fieldsPrefix)+len(body)+len(fieldsSuffix))) - writeRequestBody(buf, body) - } - req, err := http.NewRequest(method, url, buf) + }() + cr := &countingReader{reader: r} + req, err := http.NewRequest(method, url, cr) if err != nil { - return nil, err + return nil, cr, err } if useGzip { req.Header.Set("Content-Encoding", "gzip") } req.Header.Set("Content-Type", "application/json; charset=utf-8") - req.ContentLength = int64(buf.Len()) - return req, nil + return req, cr, nil } func (c *Client) clientTimeout() time.Duration { @@ -242,7 +258,7 @@ func (c *Client) Send(document Document) Result { start := c.now() result := Result{Id: document.Id, Stats: Stats{Requests: 1}} method, url := c.methodAndURL(document) - req, err := c.createRequest(method, url, document.Fields) + req, cr, err := c.createRequest(method, url, document.Fields) if err != nil { return resultWithErr(result, err) } @@ -252,7 +268,11 @@ func (c *Client) Send(document Document) Result { } defer resp.Body.Close() elapsed := c.now().Sub(start) - return c.resultWithResponse(resp, req.ContentLength, result, elapsed) + var bytesRead int64 + if cr != nil { + bytesRead = cr.bytesRead + } + return c.resultWithResponse(resp, bytesRead, result, elapsed) } func resultWithErr(result Result, err error) Result { diff --git a/client/go/internal/vespa/document/http_test.go b/client/go/internal/vespa/document/http_test.go index a582cd7ec6e..9a47b4f45fe 100644 --- a/client/go/internal/vespa/document/http_test.go +++ b/client/go/internal/vespa/document/http_test.go @@ -3,7 +3,6 @@ package document import ( "bytes" "fmt" - "io" "net/http" "reflect" "strings" @@ -63,7 +62,7 @@ func TestClientSend(t *testing.T) { {Create: true, Id: mustParseId("id:ns:type::doc2"), Operation: OperationUpdate, Fields: []byte(`{"foo": "456"}`)}, {Create: true, Id: mustParseId("id:ns:type::doc3"), Operation: OperationUpdate, Fields: []byte(`{"baz": "789"}`)}, } - httpClient := mock.HTTPClient{} + httpClient := mock.HTTPClient{ReadBody: true} client, _ := NewClient(ClientOptions{ BaseURL: "https://example.com:1337", Timeout: time.Duration(5 * time.Second), @@ -80,7 +79,6 @@ func TestClientSend(t *testing.T) { TotalLatency: time.Second, MinLatency: time.Second, MaxLatency: time.Second, - BytesSent: 25, }, } if i < 2 { @@ -100,6 +98,7 @@ func TestClientSend(t *testing.T) { wantRes.Stats.BytesRecv = 36 } res := client.Send(doc) + wantRes.Stats.BytesSent = int64(len(httpClient.LastBody)) if !reflect.DeepEqual(res, wantRes) { t.Fatalf("got result %+v, want %+v", res, wantRes) } @@ -112,19 +111,12 @@ func TestClientSend(t *testing.T) { if r.URL.String() != wantURL { t.Errorf("got r.URL = %q, want %q", r.URL, wantURL) } - body, err := io.ReadAll(r.Body) - if err != nil { - t.Fatalf("got unexpected error %q", err) - } var wantBody bytes.Buffer wantBody.WriteString(`{"fields":`) wantBody.Write(doc.Fields) wantBody.WriteString("}") - if !bytes.Equal(body, wantBody.Bytes()) { - t.Errorf("got r.Body = %q, want %q", string(body), wantBody.String()) - } - if r.ContentLength != int64(len(body)) { - t.Errorf("got r.ContentLength=%d, want %d", r.ContentLength, len(body)) + if !bytes.Equal(httpClient.LastBody, wantBody.Bytes()) { + t.Errorf("got r.Body = %q, want %q", string(httpClient.LastBody), wantBody.String()) } } want := Stats{ @@ -148,52 +140,50 @@ func TestClientSend(t *testing.T) { } func TestClientSendCompressed(t *testing.T) { - httpClient := mock.HTTPClient{} + httpClient := &mock.HTTPClient{ReadBody: true} client, _ := NewClient(ClientOptions{ BaseURL: "https://example.com:1337", Timeout: time.Duration(5 * time.Second), - }, []util.HTTPClient{&httpClient}) + }, []util.HTTPClient{httpClient}) bigBody := fmt.Sprintf(`{"foo": "%s"}`, strings.Repeat("s", 512+1)) bigDoc := Document{Create: true, Id: mustParseId("id:ns:type::doc1"), Operation: OperationUpdate, Fields: []byte(bigBody)} smallDoc := Document{Create: true, Id: mustParseId("id:ns:type::doc2"), Operation: OperationUpdate, Fields: []byte(`{"foo": "s"}`)} + var result Result client.options.Compression = CompressionNone - _ = client.Send(bigDoc) - assertCompressedRequest(t, false, httpClient.LastRequest) - _ = client.Send(smallDoc) - assertCompressedRequest(t, false, httpClient.LastRequest) + result = client.Send(bigDoc) + assertCompressedRequest(t, false, result, httpClient) + result = client.Send(smallDoc) + assertCompressedRequest(t, false, result, httpClient) client.options.Compression = CompressionAuto - _ = client.Send(bigDoc) - assertCompressedRequest(t, true, httpClient.LastRequest) - _ = client.Send(smallDoc) - assertCompressedRequest(t, false, httpClient.LastRequest) + result = client.Send(bigDoc) + assertCompressedRequest(t, true, result, httpClient) + result = client.Send(smallDoc) + assertCompressedRequest(t, false, result, httpClient) client.options.Compression = CompressionGzip - _ = client.Send(bigDoc) - assertCompressedRequest(t, true, httpClient.LastRequest) - _ = client.Send(smallDoc) - assertCompressedRequest(t, true, httpClient.LastRequest) + result = client.Send(bigDoc) + assertCompressedRequest(t, true, result, httpClient) + result = client.Send(smallDoc) + assertCompressedRequest(t, true, result, httpClient) } -func assertCompressedRequest(t *testing.T, want bool, request *http.Request) { +func assertCompressedRequest(t *testing.T, want bool, result Result, client *mock.HTTPClient) { + t.Helper() wantEnc := "" if want { wantEnc = "gzip" } - gotEnc := request.Header.Get("Content-Encoding") + gotEnc := client.LastRequest.Header.Get("Content-Encoding") if gotEnc != wantEnc { t.Errorf("got Content-Encoding=%q, want %q", gotEnc, wantEnc) } - body, err := io.ReadAll(request.Body) - if err != nil { - t.Fatal(err) - } - if request.ContentLength != int64(len(body)) { - t.Errorf("got ContentLength=%d, want %d", request.ContentLength, len(body)) + if result.Stats.BytesSent != int64(len(client.LastBody)) { + t.Errorf("got BytesSent=%d, want %d", result.Stats.BytesSent, len(client.LastBody)) } - compressed := bytes.HasPrefix(body, []byte{0x1f, 0x8b}) + compressed := bytes.HasPrefix(client.LastBody, []byte{0x1f, 0x8b}) if compressed != want { t.Errorf("got compressed=%t, want %t", compressed, want) } -- cgit v1.2.3