From b1d41ebbca435ad4e8a40d154af9c158f5394e79 Mon Sep 17 00:00:00 2001 From: tiancaiamao Date: Wed, 19 Oct 2016 11:39:45 +0800 Subject: [PATCH] distsql: fix goroutine leak caused huge memory footprint (#1834) --- distsql/distsql.go | 25 ++++++++++++--- distsql/distsql_test.go | 69 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 5 deletions(-) diff --git a/distsql/distsql.go b/distsql/distsql.go index 81cb477da9..0654861d4c 100644 --- a/distsql/distsql.go +++ b/distsql/distsql.go @@ -74,6 +74,8 @@ type selectResult struct { results chan PartialResult done chan error + + closed chan struct{} } func (r *selectResult) Fetch() { @@ -97,10 +99,16 @@ func (r *selectResult) fetch() { reader: reader, aggregate: r.aggregate, ignoreData: r.ignoreData, - done: make(chan error), + done: make(chan error, 1), } go pr.fetch() - r.results <- pr + + select { + case r.results <- pr: + case <-r.closed: + // if selectResult called Close() already, make fetch goroutine exit + return + } } } @@ -131,6 +139,8 @@ func (r *selectResult) IgnoreData() { // Close closes SelectResult. func (r *selectResult) Close() error { + // close this channel tell fetch goroutine to exit + close(r.closed) return r.resp.Close() } @@ -151,21 +161,27 @@ type partialResult struct { } func (pr *partialResult) fetch() { + defer close(pr.done) pr.resp = new(tipb.SelectResponse) + b, err := ioutil.ReadAll(pr.reader) pr.reader.Close() if err != nil { pr.done <- errors.Trace(err) return } + err = pr.resp.Unmarshal(b) if err != nil { pr.done <- errors.Trace(err) return } + if pr.resp.Error != nil { pr.done <- errInvalidResp.Gen("[%d %s]", pr.resp.Error.GetCode(), pr.resp.Error.GetMsg()) + return } + pr.done <- nil } @@ -175,9 +191,7 @@ var dummyData = make([]types.Datum, 0) // If no more row to return, data would be nil. func (pr *partialResult) Next() (handle int64, data []types.Datum, err error) { if !pr.fetched { - select { - case err = <-pr.done: - } + err = <-pr.done pr.fetched = true if err != nil { return 0, nil, err @@ -289,6 +303,7 @@ func Select(client kv.Client, req *tipb.SelectRequest, keyRanges []kv.KeyRange, resp: resp, results: make(chan PartialResult, 5), done: make(chan error, 1), + closed: make(chan struct{}), } // If Aggregates is not nil, we should set result fields latter. if len(req.Aggregates) == 0 && len(req.GroupBy) == 0 { diff --git a/distsql/distsql_test.go b/distsql/distsql_test.go index 95018ed712..df32a5caa9 100644 --- a/distsql/distsql_test.go +++ b/distsql/distsql_test.go @@ -14,13 +14,20 @@ package distsql import ( + "bytes" + "errors" + "io" + "io/ioutil" + "runtime" "testing" + "time" . "github.com/pingcap/check" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/testleak" "github.com/pingcap/tidb/util/types" + "github.com/pingcap/tipb/go-tipb" ) func TestT(t *testing.T) { @@ -43,3 +50,65 @@ func (s *testTableCodecSuite) TestColumnToProto(c *C) { pc := columnToProto(col) c.Assert(pc.GetFlag(), Equals, int32(10)) } + +// For issue 1791 +func (s *testTableCodecSuite) TestGoroutineLeak(c *C) { + var sr SelectResult + countBefore := runtime.NumGoroutine() + + sr = &selectResult{ + resp: &mockResponse{}, + results: make(chan PartialResult, 5), + done: make(chan error, 1), + closed: make(chan struct{}), + } + go sr.Fetch() + for { + // mock test will generate some partial result then return error + _, err := sr.Next() + if err != nil { + // close selectResult on error, partialResult's fetch goroutine may leak + sr.Close() + break + } + } + + tick := 10 * time.Millisecond + totalSleep := time.Duration(0) + for totalSleep < 3*time.Second { + time.Sleep(tick) + totalSleep += tick + countAfter := runtime.NumGoroutine() + + if countAfter-countBefore < 5 { + return + } + } + + c.Error("distsql goroutine leak!") +} + +type mockResponse struct { + count int +} + +func (resp *mockResponse) Next() (io.ReadCloser, error) { + resp.count++ + if resp.count == 100 { + return nil, errors.New("error happend") + } + return mockReaderCloser(), nil +} + +func (resp *mockResponse) Close() error { + return nil +} + +func mockReaderCloser() io.ReadCloser { + resp := new(tipb.SelectResponse) + b, err := resp.Marshal() + if err != nil { + panic(err) + } + return ioutil.NopCloser(bytes.NewBuffer(b)) +}