From 4780bb478e2055456ea5d6da2fa3df28de3bc997 Mon Sep 17 00:00:00 2001 From: Dylan Wen Date: Fri, 31 Mar 2017 11:13:24 +0800 Subject: [PATCH] store/tikv: use backoffer context for pd client calls (#2953) --- store/tikv/2pc_test.go | 2 +- store/tikv/kv.go | 2 +- store/tikv/lock_test.go | 7 ++++--- store/tikv/oracle/oracle.go | 8 ++++++-- store/tikv/oracle/oracles/local.go | 3 ++- store/tikv/oracle/oracles/local_test.go | 6 ++++-- store/tikv/oracle/oracles/pd.go | 17 +++++++++-------- store/tikv/region_cache.go | 6 +++--- store/tikv/store_test.go | 6 +++--- 9 files changed, 33 insertions(+), 24 deletions(-) diff --git a/store/tikv/2pc_test.go b/store/tikv/2pc_test.go index 4b1ab66214..dab109a5e2 100644 --- a/store/tikv/2pc_test.go +++ b/store/tikv/2pc_test.go @@ -145,7 +145,7 @@ func (s *testCommitterSuite) TestPrewriteRollback(c *C) { err = committer.prewriteKeys(NewBackoffer(prewriteMaxBackoff, ctx), committer.keys) c.Assert(err, IsNil) } - committer.commitTS, err = s.store.oracle.GetTimestamp() + committer.commitTS, err = s.store.oracle.GetTimestamp(ctx) c.Assert(err, IsNil) err = committer.commitKeys(NewBackoffer(commitMaxBackoff, ctx), [][]byte{[]byte("a")}) c.Assert(err, IsNil) diff --git a/store/tikv/kv.go b/store/tikv/kv.go index 95e565e140..f8788c115e 100644 --- a/store/tikv/kv.go +++ b/store/tikv/kv.go @@ -196,7 +196,7 @@ func (s *tikvStore) CurrentVersion() (kv.Version, error) { func (s *tikvStore) getTimestampWithRetry(bo *Backoffer) (uint64, error) { for { - startTS, err := s.oracle.GetTimestamp() + startTS, err := s.oracle.GetTimestamp(bo.ctx) if err == nil { return startTS, nil } diff --git a/store/tikv/lock_test.go b/store/tikv/lock_test.go index 9a6610f21e..a4118bf0f4 100644 --- a/store/tikv/lock_test.go +++ b/store/tikv/lock_test.go @@ -56,13 +56,14 @@ func (s *testLockSuite) lockKey(c *C, key, value, primaryKey, primaryValue []byt c.Assert(err, IsNil) tpc.keys = [][]byte{primaryKey, key} - err = tpc.prewriteKeys(NewBackoffer(prewriteMaxBackoff, goctx.Background()), tpc.keys) + ctx := goctx.Background() + err = tpc.prewriteKeys(NewBackoffer(prewriteMaxBackoff, ctx), tpc.keys) c.Assert(err, IsNil) if commitPrimary { - tpc.commitTS, err = s.store.oracle.GetTimestamp() + tpc.commitTS, err = s.store.oracle.GetTimestamp(ctx) c.Assert(err, IsNil) - err = tpc.commitKeys(NewBackoffer(commitMaxBackoff, goctx.Background()), [][]byte{primaryKey}) + err = tpc.commitKeys(NewBackoffer(commitMaxBackoff, ctx), [][]byte{primaryKey}) c.Assert(err, IsNil) } return txn.startTS, tpc.commitTS diff --git a/store/tikv/oracle/oracle.go b/store/tikv/oracle/oracle.go index c52962c042..10ed5f668d 100644 --- a/store/tikv/oracle/oracle.go +++ b/store/tikv/oracle/oracle.go @@ -13,11 +13,15 @@ package oracle -import "time" +import ( + "time" + + "golang.org/x/net/context" +) // Oracle is the interface that provides strictly ascending timestamps. type Oracle interface { - GetTimestamp() (uint64, error) + GetTimestamp(ctx context.Context) (uint64, error) IsExpired(lockTimestamp uint64, TTL uint64) bool Close() } diff --git a/store/tikv/oracle/oracles/local.go b/store/tikv/oracle/oracles/local.go index 05075d2671..a89b633ea4 100644 --- a/store/tikv/oracle/oracles/local.go +++ b/store/tikv/oracle/oracles/local.go @@ -18,6 +18,7 @@ import ( "time" "github.com/pingcap/tidb/store/tikv/oracle" + "golang.org/x/net/context" ) var _ oracle.Oracle = &localOracle{} @@ -37,7 +38,7 @@ func (l *localOracle) IsExpired(lockTS uint64, TTL uint64) bool { return oracle.GetPhysical(time.Now()) >= oracle.ExtractPhysical(lockTS)+int64(TTL) } -func (l *localOracle) GetTimestamp() (uint64, error) { +func (l *localOracle) GetTimestamp(context.Context) (uint64, error) { l.Lock() defer l.Unlock() physical := oracle.GetPhysical(time.Now()) diff --git a/store/tikv/oracle/oracles/local_test.go b/store/tikv/oracle/oracles/local_test.go index 9c77f013b1..7bf7659d30 100644 --- a/store/tikv/oracle/oracles/local_test.go +++ b/store/tikv/oracle/oracles/local_test.go @@ -16,6 +16,8 @@ package oracles import ( "testing" "time" + + "golang.org/x/net/context" ) func TestLocalOracle(t *testing.T) { @@ -23,7 +25,7 @@ func TestLocalOracle(t *testing.T) { defer l.Close() m := map[uint64]struct{}{} for i := 0; i < 100000; i++ { - ts, err := l.GetTimestamp() + ts, err := l.GetTimestamp(context.Background()) if err != nil { t.Error(err) } @@ -38,7 +40,7 @@ func TestLocalOracle(t *testing.T) { func TestIsExpired(t *testing.T) { o := NewLocalOracle() defer o.Close() - ts, _ := o.GetTimestamp() + ts, _ := o.GetTimestamp(context.Background()) time.Sleep(1 * time.Second) expire := o.IsExpired(uint64(ts), 500) if !expire { diff --git a/store/tikv/oracle/oracles/pd.go b/store/tikv/oracle/oracles/pd.go index a925a6640b..5a4721482c 100644 --- a/store/tikv/oracle/oracles/pd.go +++ b/store/tikv/oracle/oracles/pd.go @@ -45,9 +45,10 @@ func NewPdOracle(pdClient pd.Client, updateInterval time.Duration) (oracle.Oracl c: pdClient, quit: make(chan struct{}), } - go o.updateTS(updateInterval) + ctx := context.TODO() + go o.updateTS(ctx, updateInterval) // Initialize lastTS by Get. - _, err := o.GetTimestamp() + _, err := o.GetTimestamp(ctx) if err != nil { o.Close() return nil, errors.Trace(err) @@ -63,8 +64,8 @@ func (o *pdOracle) IsExpired(lockTS, TTL uint64) bool { } // GetTimestamp gets a new increasing time. -func (o *pdOracle) GetTimestamp() (uint64, error) { - ts, err := o.getTimestamp() +func (o *pdOracle) GetTimestamp(ctx context.Context) (uint64, error) { + ts, err := o.getTimestamp(ctx) if err != nil { return 0, errors.Trace(err) } @@ -72,9 +73,9 @@ func (o *pdOracle) GetTimestamp() (uint64, error) { return ts, nil } -func (o *pdOracle) getTimestamp() (uint64, error) { +func (o *pdOracle) getTimestamp(ctx context.Context) (uint64, error) { now := time.Now() - physical, logical, err := o.c.GetTS(context.TODO()) + physical, logical, err := o.c.GetTS(ctx) if err != nil { return 0, errors.Trace(err) } @@ -92,12 +93,12 @@ func (o *pdOracle) setLastTS(ts uint64) { } } -func (o *pdOracle) updateTS(interval time.Duration) { +func (o *pdOracle) updateTS(ctx context.Context, interval time.Duration) { ticker := time.NewTicker(interval) for { select { case <-ticker.C: - ts, err := o.getTimestamp() + ts, err := o.getTimestamp(ctx) if err != nil { log.Errorf("updateTS error: %v", err) break diff --git a/store/tikv/region_cache.go b/store/tikv/region_cache.go index c10c2f7a02..b01edeec80 100644 --- a/store/tikv/region_cache.go +++ b/store/tikv/region_cache.go @@ -289,7 +289,7 @@ func (c *RegionCache) loadRegion(bo *Backoffer, key []byte) (*Region, error) { } } - meta, leader, err := c.pdClient.GetRegion(goctx.TODO(), key) + meta, leader, err := c.pdClient.GetRegion(bo.ctx, key) if err != nil { backoffErr = errors.Errorf("loadRegion from PD failed, key: %q, err: %v", key, err) continue @@ -323,7 +323,7 @@ func (c *RegionCache) loadRegionByID(bo *Backoffer, regionID uint64) (*Region, e } } - meta, leader, err := c.pdClient.GetRegionByID(goctx.TODO(), regionID) + meta, leader, err := c.pdClient.GetRegionByID(bo.ctx, regionID) if err != nil { backoffErr = errors.Errorf("loadRegion from PD failed, regionID: %v, err: %v", regionID, err) continue @@ -383,7 +383,7 @@ func (c *RegionCache) ClearStoreByID(id uint64) { func (c *RegionCache) loadStoreAddr(bo *Backoffer, id uint64) (string, error) { for { - store, err := c.pdClient.GetStore(goctx.TODO(), id) + store, err := c.pdClient.GetStore(bo.ctx, id) if err != nil { err = errors.Errorf("loadStore from PD failed, id: %d, err: %v", id, err) if err = bo.Backoff(boPDRPC, err); err != nil { diff --git a/store/tikv/store_test.go b/store/tikv/store_test.go index 44f07e0a66..5669e53341 100644 --- a/store/tikv/store_test.go +++ b/store/tikv/store_test.go @@ -23,9 +23,9 @@ import ( "github.com/pingcap/kvproto/pkg/errorpb" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/pd/pd-client" + pd "github.com/pingcap/pd/pd-client" "github.com/pingcap/tidb" - "github.com/pingcap/tidb/store/tikv/mock-tikv" + mocktikv "github.com/pingcap/tidb/store/tikv/mock-tikv" "github.com/pingcap/tidb/store/tikv/oracle" goctx "golang.org/x/net/context" ) @@ -195,7 +195,7 @@ func (o *mockOracle) addOffset(d time.Duration) { o.offset += d } -func (o *mockOracle) GetTimestamp() (uint64, error) { +func (o *mockOracle) GetTimestamp(goctx.Context) (uint64, error) { o.Lock() defer o.Unlock()