diff --git a/api/client.gen.go b/api/client.gen.go index 4180a77..2ab2728 100644 --- a/api/client.gen.go +++ b/api/client.gen.go @@ -12,13 +12,13 @@ package api import ( "bytes" - "compress/gzip" "context" "encoding/json" "encoding/xml" "errors" "fmt" "io" + "io/ioutil" "log" "mime/multipart" "net/http" @@ -232,7 +232,7 @@ func (c *APIClient) prepareRequest( fileName string, fileBytes []byte) (localVarRequest *http.Request, err error) { - var body *bytes.Buffer + var body io.ReadCloser // Detect postBody type and post. if postBody != nil { @@ -253,8 +253,8 @@ func (c *APIClient) prepareRequest( if body != nil { return nil, errors.New("cannot specify postBody and multipart form at the same time") } - body = &bytes.Buffer{} - w := multipart.NewWriter(body) + buf := &bytes.Buffer{} + w := multipart.NewWriter(buf) for k, v := range formParams { for _, iv := range v { @@ -285,18 +285,20 @@ func (c *APIClient) prepareRequest( headerParams["Content-Type"] = w.FormDataContentType() // Set Content-Length - headerParams["Content-Length"] = fmt.Sprintf("%d", body.Len()) + headerParams["Content-Length"] = fmt.Sprintf("%d", buf.Len()) w.Close() + body = ioutil.NopCloser(buf) } if strings.HasPrefix(headerParams["Content-Type"], "application/x-www-form-urlencoded") && len(formParams) > 0 { if body != nil { return nil, errors.New("cannot specify postBody and x-www-form-urlencoded form at the same time") } - body = &bytes.Buffer{} - body.WriteString(formParams.Encode()) + buf := &bytes.Buffer{} + buf.WriteString(formParams.Encode()) + body = ioutil.NopCloser(buf) // Set Content-Length - headerParams["Content-Length"] = fmt.Sprintf("%d", body.Len()) + headerParams["Content-Length"] = fmt.Sprintf("%d", buf.Len()) } // Setup path and query parameters @@ -327,18 +329,7 @@ func (c *APIClient) prepareRequest( url.RawQuery = query.Encode() // Generate a new request - if body != nil { - var b io.Reader = body - if enc, ok := headerParams["Content-Encoding"]; ok && enc == "gzip" { - b, err = compressWithGzip(b) - if err != nil { - return nil, err - } - } - localVarRequest, err = http.NewRequest(method, url.String(), b) - } else { - localVarRequest, err = http.NewRequest(method, url.String(), nil) - } + localVarRequest, err = http.NewRequest(method, url.String(), body) if err != nil { return nil, err } @@ -433,16 +424,20 @@ func reportError(format string, a ...interface{}) error { } // Set request body from an interface{} -func setBody(body interface{}, contentType string) (bodyBuf *bytes.Buffer, err error) { - if bodyBuf == nil { - bodyBuf = &bytes.Buffer{} +// NOTE: Assumes that `body` is non-nil. +func setBody(body interface{}, contentType string) (io.ReadCloser, error) { + if rc, ok := body.(io.ReadCloser); ok { + return rc, nil + } else if reader, ok := body.(io.Reader); ok { + return ioutil.NopCloser(reader), nil + } else if fp, ok := body.(**os.File); ok { + return *fp, nil } - if reader, ok := body.(io.Reader); ok { - _, err = bodyBuf.ReadFrom(reader) - } else if fp, ok := body.(**os.File); ok { - _, err = bodyBuf.ReadFrom(*fp) - } else if b, ok := body.([]byte); ok { + var err error + bodyBuf := &bytes.Buffer{} + + if b, ok := body.([]byte); ok { _, err = bodyBuf.Write(b) } else if s, ok := body.(string); ok { _, err = bodyBuf.WriteString(s) @@ -459,24 +454,9 @@ func setBody(body interface{}, contentType string) (bodyBuf *bytes.Buffer, err e } if bodyBuf.Len() == 0 { - err = fmt.Errorf("invalid body type %s", contentType) - return nil, err + return nil, fmt.Errorf("invalid body type %s", contentType) } - return bodyBuf, nil -} - -func compressWithGzip(data io.Reader) (io.Reader, error) { - pr, pw := io.Pipe() - gw := gzip.NewWriter(pw) - var err error - - go func() { - _, err = io.Copy(gw, data) - gw.Close() - pw.Close() - }() - - return pr, err + return ioutil.NopCloser(bodyBuf), nil } // detectContentType method is used to figure out `Request.Body` content type for request header diff --git a/api/client_internal_test.go b/api/client_internal_test.go deleted file mode 100644 index 217c73e..0000000 --- a/api/client_internal_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package api - -import ( - "bytes" - "compress/gzip" - "context" - "io" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestNoGzipRequest(t *testing.T) { - client := APIClient{cfg: NewConfiguration()} - body := []byte("This should not get gzipped") - req, err := client.prepareRequest( - context.Background(), - "/foo", "POST", body, - map[string]string{}, - nil, nil, "", "", nil, - ) - require.NoError(t, err) - defer req.Body.Close() - - out := bytes.Buffer{} - _, err = io.Copy(&out, req.Body) - require.NoError(t, err) - - require.Equal(t, string(body), out.String()) -} - -func TestGzipRequest(t *testing.T) { - client := APIClient{cfg: NewConfiguration()} - body := []byte("This should get gzipped") - req, err := client.prepareRequest( - context.Background(), - "/foo", "POST", body, - map[string]string{"Content-Encoding": "gzip"}, - nil, nil, "", "", nil, - ) - require.NoError(t, err) - defer req.Body.Close() - - out := bytes.Buffer{} - gzr, err := gzip.NewReader(req.Body) - require.NoError(t, err) - defer gzr.Close() - _, err = io.Copy(&out, gzr) - require.NoError(t, err) - - require.Equal(t, string(body), out.String()) -} diff --git a/api/templates/README.md b/api/templates/README.md index d3bd6b6..098facb 100644 --- a/api/templates/README.md +++ b/api/templates/README.md @@ -24,8 +24,8 @@ multiple locations. * Removed use of `golang.org/x/oauth2` to avoid its heavy dependencies * Fixed error strings to be idiomatic according to staticcheck (lowercase, no punctuation) * Use `strings.EqualFold` instead of comparing two `strings.ToLower` calls -* GZip request bodies when `Content-Encoding: gzip` is set * Update the `GenericOpenAPIError` type to enforce that error response models implement the `error` interface +* Update `setBody` to avoid buffering data in memory when the request body is already an `io.ReadCloser` `configuration.mustache` * Deleted `ContextOAuth2` key to match modification in client diff --git a/api/templates/client.mustache b/api/templates/client.mustache index 89ec13d..b5fa081 100644 --- a/api/templates/client.mustache +++ b/api/templates/client.mustache @@ -3,13 +3,13 @@ package {{packageName}} import ( "bytes" - "compress/gzip" "context" "encoding/json" "encoding/xml" "errors" "fmt" "io" + "io/ioutil" "log" "mime/multipart" "net/http" @@ -211,7 +211,7 @@ func (c *APIClient) prepareRequest( fileName string, fileBytes []byte) (localVarRequest *http.Request, err error) { - var body *bytes.Buffer + var body io.ReadCloser // Detect postBody type and post. if postBody != nil { @@ -232,8 +232,8 @@ func (c *APIClient) prepareRequest( if body != nil { return nil, errors.New("cannot specify postBody and multipart form at the same time") } - body = &bytes.Buffer{} - w := multipart.NewWriter(body) + buf := &bytes.Buffer{} + w := multipart.NewWriter(buf) for k, v := range formParams { for _, iv := range v { @@ -264,18 +264,20 @@ func (c *APIClient) prepareRequest( headerParams["Content-Type"] = w.FormDataContentType() // Set Content-Length - headerParams["Content-Length"] = fmt.Sprintf("%d", body.Len()) + headerParams["Content-Length"] = fmt.Sprintf("%d", buf.Len()) w.Close() + body = ioutil.NopCloser(buf) } if strings.HasPrefix(headerParams["Content-Type"], "application/x-www-form-urlencoded") && len(formParams) > 0 { if body != nil { return nil, errors.New("cannot specify postBody and x-www-form-urlencoded form at the same time") } - body = &bytes.Buffer{} - body.WriteString(formParams.Encode()) + buf := &bytes.Buffer{} + buf.WriteString(formParams.Encode()) + body = ioutil.NopCloser(buf) // Set Content-Length - headerParams["Content-Length"] = fmt.Sprintf("%d", body.Len()) + headerParams["Content-Length"] = fmt.Sprintf("%d", buf.Len()) } // Setup path and query parameters @@ -306,18 +308,7 @@ func (c *APIClient) prepareRequest( url.RawQuery = query.Encode() // Generate a new request - if body != nil { - var b io.Reader = body - if enc, ok := headerParams["Content-Encoding"]; ok && enc == "gzip" { - b, err = compressWithGzip(b) - if err != nil { - return nil, err - } - } - localVarRequest, err = http.NewRequest(method, url.String(), b) - } else { - localVarRequest, err = http.NewRequest(method, url.String(), nil) - } + localVarRequest, err = http.NewRequest(method, url.String(), body) if err != nil { return nil, err } @@ -453,16 +444,20 @@ func reportError(format string, a ...interface{}) error { } // Set request body from an interface{} -func setBody(body interface{}, contentType string) (bodyBuf *bytes.Buffer, err error) { - if bodyBuf == nil { - bodyBuf = &bytes.Buffer{} +// NOTE: Assumes that `body` is non-nil. +func setBody(body interface{}, contentType string) (io.ReadCloser, error) { + if rc, ok := body.(io.ReadCloser); ok { + return rc, nil + } else if reader, ok := body.(io.Reader); ok { + return ioutil.NopCloser(reader), nil + } else if fp, ok := body.(**os.File); ok { + return *fp, nil } - if reader, ok := body.(io.Reader); ok { - _, err = bodyBuf.ReadFrom(reader) - } else if fp, ok := body.(**os.File); ok { - _, err = bodyBuf.ReadFrom(*fp) - } else if b, ok := body.([]byte); ok { + var err error + bodyBuf := &bytes.Buffer{} + + if b, ok := body.([]byte); ok { _, err = bodyBuf.Write(b) } else if s, ok := body.(string); ok { _, err = bodyBuf.WriteString(s) @@ -479,24 +474,9 @@ func setBody(body interface{}, contentType string) (bodyBuf *bytes.Buffer, err e } if bodyBuf.Len() == 0 { - err = fmt.Errorf("invalid body type %s", contentType) - return nil, err + return nil, fmt.Errorf("invalid body type %s", contentType) } - return bodyBuf, nil -} - -func compressWithGzip(data io.Reader) (io.Reader, error) { - pr, pw := io.Pipe() - gw := gzip.NewWriter(pw) - var err error - - go func() { - _, err = io.Copy(gw, data) - gw.Close() - pw.Close() - }() - - return pr, err + return ioutil.NopCloser(bodyBuf), nil } // detectContentType method is used to figure out `Request.Body` content type for request header diff --git a/clients/write/write.go b/clients/write/write.go index 280557e..e147432 100644 --- a/clients/write/write.go +++ b/clients/write/write.go @@ -1,6 +1,8 @@ package write import ( + "bytes" + "compress/gzip" "context" "errors" "fmt" @@ -57,7 +59,15 @@ func (c Client) Write(ctx context.Context, params *Params) error { } writeBatch := func(batch []byte) error { - req := c.PostWrite(ctx).Body(batch).ContentEncoding("gzip").Precision(params.Precision) + buf := bytes.Buffer{} + gzw := gzip.NewWriter(&buf) + _, err := gzw.Write(batch) + gzw.Close() + if err != nil { + return err + } + + req := c.PostWrite(ctx).Body(buf.Bytes()).ContentEncoding("gzip").Precision(params.Precision) if params.BucketID != "" { req = req.Bucket(params.BucketID) } else { diff --git a/clients/write/write_test.go b/clients/write/write_test.go index 727430c..dcde167 100644 --- a/clients/write/write_test.go +++ b/clients/write/write_test.go @@ -2,6 +2,7 @@ package write_test import ( "bytes" + "compress/gzip" "context" "io" "io/ioutil" @@ -79,9 +80,16 @@ func TestWriteByIDs(t *testing.T) { return assert.Equal(t, params.OrgID, *in.GetOrg()) && assert.Equal(t, params.BucketID, *in.GetBucket()) && assert.Equal(t, params.Precision, *in.GetPrecision()) && - assert.Equal(t, "gzip", *in.GetContentEncoding()) // Make sure the body is properly marked for compression. + assert.Equal(t, "gzip", *in.GetContentEncoding()) })).DoAndReturn(func(in api.ApiPostWriteRequest) error { - writtenLines = append(writtenLines, string(in.GetBody())) + bodyBytes := bytes.NewReader(in.GetBody()) + gzr, err := gzip.NewReader(bodyBytes) + require.NoError(t, err) + defer gzr.Close() + buf := bytes.Buffer{} + _, err = buf.ReadFrom(gzr) + require.NoError(t, err) + writtenLines = append(writtenLines, buf.String()) return nil }).Times(len(inLines)) @@ -124,9 +132,16 @@ func TestWriteByNames(t *testing.T) { return assert.Equal(t, params.OrgName, *in.GetOrg()) && assert.Equal(t, params.BucketName, *in.GetBucket()) && assert.Equal(t, params.Precision, *in.GetPrecision()) && - assert.Equal(t, "gzip", *in.GetContentEncoding()) // Make sure the body is properly marked for compression. + assert.Equal(t, "gzip", *in.GetContentEncoding()) })).DoAndReturn(func(in api.ApiPostWriteRequest) error { - writtenLines = append(writtenLines, string(in.GetBody())) + bodyBytes := bytes.NewReader(in.GetBody()) + gzr, err := gzip.NewReader(bodyBytes) + require.NoError(t, err) + defer gzr.Close() + buf := bytes.Buffer{} + _, err = buf.ReadFrom(gzr) + require.NoError(t, err) + writtenLines = append(writtenLines, buf.String()) return nil }).Times(len(inLines)) @@ -171,7 +186,14 @@ func TestWriteOrgFromConfig(t *testing.T) { assert.Equal(t, params.Precision, *in.GetPrecision()) && assert.Equal(t, "gzip", *in.GetContentEncoding()) // Make sure the body is properly marked for compression. })).DoAndReturn(func(in api.ApiPostWriteRequest) error { - writtenLines = append(writtenLines, string(in.GetBody())) + bodyBytes := bytes.NewReader(in.GetBody()) + gzr, err := gzip.NewReader(bodyBytes) + require.NoError(t, err) + defer gzr.Close() + buf := bytes.Buffer{} + _, err = buf.ReadFrom(gzr) + require.NoError(t, err) + writtenLines = append(writtenLines, buf.String()) return nil }).Times(len(inLines)) diff --git a/pkg/gzip/pipe.go b/pkg/gzip/pipe.go new file mode 100644 index 0000000..f72290c --- /dev/null +++ b/pkg/gzip/pipe.go @@ -0,0 +1,47 @@ +package gzip + +import ( + "compress/gzip" + "io" +) + +var _ io.ReadCloser = (*gzipPipe)(nil) + +type gzipPipe struct { + underlying io.ReadCloser + pipeOut io.ReadCloser +} + +// NewGzipPipe returns an io.ReadCloser that wraps an input data stream, +// applying gzip compression to the underlying data on Read and closing the +// underlying data on Close. +func NewGzipPipe(in io.ReadCloser) *gzipPipe { + pr, pw := io.Pipe() + gw := gzip.NewWriter(pw) + + go func() { + _, err := io.Copy(gw, in) + gw.Close() + if err != nil { + pw.CloseWithError(err) + } else { + pw.Close() + } + }() + + return &gzipPipe{underlying: in, pipeOut: pr} +} + +func (gzp gzipPipe) Read(p []byte) (int, error) { + return gzp.pipeOut.Read(p) +} + +func (gzp gzipPipe) Close() error { + if err := gzp.pipeOut.Close(); err != nil { + return err + } + if err := gzp.underlying.Close(); err != nil { + return err + } + return nil +} diff --git a/pkg/gzip/pipe_test.go b/pkg/gzip/pipe_test.go new file mode 100644 index 0000000..076779f --- /dev/null +++ b/pkg/gzip/pipe_test.go @@ -0,0 +1,66 @@ +package gzip_test + +import ( + "bytes" + "compress/gzip" + "errors" + "io" + "io/ioutil" + "strings" + "testing" + + pgzip "github.com/influxdata/influx-cli/v2/pkg/gzip" + "github.com/stretchr/testify/require" +) + +func TestGzipPipe(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + t.Parallel() + + data := strings.Repeat("Data data I'm some data\n", 1024) + reader := strings.NewReader(data) + pipe := pgzip.NewGzipPipe(ioutil.NopCloser(reader)) + defer pipe.Close() + gunzip, err := gzip.NewReader(pipe) + require.NoError(t, err) + defer gunzip.Close() + + out := bytes.Buffer{} + _, err = io.Copy(&out, gunzip) + require.NoError(t, err) + + require.Equal(t, data, out.String()) + }) + + t.Run("error", func(t *testing.T) { + t.Parallel() + + reader := &failingReader{n: 3, err: errors.New("I BROKE")} + pipe := pgzip.NewGzipPipe(ioutil.NopCloser(reader)) + defer pipe.Close() + gunzip, err := gzip.NewReader(pipe) + require.NoError(t, err) + defer gunzip.Close() + + out := bytes.Buffer{} + _, err = io.Copy(&out, gunzip) + require.Error(t, err) + require.Equal(t, reader.err, err) + }) +} + +type failingReader struct { + n int + err error +} + +func (frc *failingReader) Read(p []byte) (int, error) { + if frc.n <= 0 { + return 0, frc.err + } + frc.n-- + p[0] = 'a' + return 1, nil +}