store/tikv: use backoffer context for pd client calls (#2953)
This commit is contained in:
@ -145,7 +145,7 @@ func (s *testCommitterSuite) TestPrewriteRollback(c *C) {
|
||||
err = committer.prewriteKeys(NewBackoffer(prewriteMaxBackoff, ctx), committer.keys)
|
||||
c.Assert(err, IsNil)
|
||||
}
|
||||
committer.commitTS, err = s.store.oracle.GetTimestamp()
|
||||
committer.commitTS, err = s.store.oracle.GetTimestamp(ctx)
|
||||
c.Assert(err, IsNil)
|
||||
err = committer.commitKeys(NewBackoffer(commitMaxBackoff, ctx), [][]byte{[]byte("a")})
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
@ -196,7 +196,7 @@ func (s *tikvStore) CurrentVersion() (kv.Version, error) {
|
||||
|
||||
func (s *tikvStore) getTimestampWithRetry(bo *Backoffer) (uint64, error) {
|
||||
for {
|
||||
startTS, err := s.oracle.GetTimestamp()
|
||||
startTS, err := s.oracle.GetTimestamp(bo.ctx)
|
||||
if err == nil {
|
||||
return startTS, nil
|
||||
}
|
||||
|
||||
@ -56,13 +56,14 @@ func (s *testLockSuite) lockKey(c *C, key, value, primaryKey, primaryValue []byt
|
||||
c.Assert(err, IsNil)
|
||||
tpc.keys = [][]byte{primaryKey, key}
|
||||
|
||||
err = tpc.prewriteKeys(NewBackoffer(prewriteMaxBackoff, goctx.Background()), tpc.keys)
|
||||
ctx := goctx.Background()
|
||||
err = tpc.prewriteKeys(NewBackoffer(prewriteMaxBackoff, ctx), tpc.keys)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
if commitPrimary {
|
||||
tpc.commitTS, err = s.store.oracle.GetTimestamp()
|
||||
tpc.commitTS, err = s.store.oracle.GetTimestamp(ctx)
|
||||
c.Assert(err, IsNil)
|
||||
err = tpc.commitKeys(NewBackoffer(commitMaxBackoff, goctx.Background()), [][]byte{primaryKey})
|
||||
err = tpc.commitKeys(NewBackoffer(commitMaxBackoff, ctx), [][]byte{primaryKey})
|
||||
c.Assert(err, IsNil)
|
||||
}
|
||||
return txn.startTS, tpc.commitTS
|
||||
|
||||
@ -13,11 +13,15 @@
|
||||
|
||||
package oracle
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// Oracle is the interface that provides strictly ascending timestamps.
|
||||
type Oracle interface {
|
||||
GetTimestamp() (uint64, error)
|
||||
GetTimestamp(ctx context.Context) (uint64, error)
|
||||
IsExpired(lockTimestamp uint64, TTL uint64) bool
|
||||
Close()
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pingcap/tidb/store/tikv/oracle"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
var _ oracle.Oracle = &localOracle{}
|
||||
@ -37,7 +38,7 @@ func (l *localOracle) IsExpired(lockTS uint64, TTL uint64) bool {
|
||||
return oracle.GetPhysical(time.Now()) >= oracle.ExtractPhysical(lockTS)+int64(TTL)
|
||||
}
|
||||
|
||||
func (l *localOracle) GetTimestamp() (uint64, error) {
|
||||
func (l *localOracle) GetTimestamp(context.Context) (uint64, error) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
physical := oracle.GetPhysical(time.Now())
|
||||
|
||||
@ -16,6 +16,8 @@ package oracles
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestLocalOracle(t *testing.T) {
|
||||
@ -23,7 +25,7 @@ func TestLocalOracle(t *testing.T) {
|
||||
defer l.Close()
|
||||
m := map[uint64]struct{}{}
|
||||
for i := 0; i < 100000; i++ {
|
||||
ts, err := l.GetTimestamp()
|
||||
ts, err := l.GetTimestamp(context.Background())
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -38,7 +40,7 @@ func TestLocalOracle(t *testing.T) {
|
||||
func TestIsExpired(t *testing.T) {
|
||||
o := NewLocalOracle()
|
||||
defer o.Close()
|
||||
ts, _ := o.GetTimestamp()
|
||||
ts, _ := o.GetTimestamp(context.Background())
|
||||
time.Sleep(1 * time.Second)
|
||||
expire := o.IsExpired(uint64(ts), 500)
|
||||
if !expire {
|
||||
|
||||
@ -45,9 +45,10 @@ func NewPdOracle(pdClient pd.Client, updateInterval time.Duration) (oracle.Oracl
|
||||
c: pdClient,
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
go o.updateTS(updateInterval)
|
||||
ctx := context.TODO()
|
||||
go o.updateTS(ctx, updateInterval)
|
||||
// Initialize lastTS by Get.
|
||||
_, err := o.GetTimestamp()
|
||||
_, err := o.GetTimestamp(ctx)
|
||||
if err != nil {
|
||||
o.Close()
|
||||
return nil, errors.Trace(err)
|
||||
@ -63,8 +64,8 @@ func (o *pdOracle) IsExpired(lockTS, TTL uint64) bool {
|
||||
}
|
||||
|
||||
// GetTimestamp gets a new increasing time.
|
||||
func (o *pdOracle) GetTimestamp() (uint64, error) {
|
||||
ts, err := o.getTimestamp()
|
||||
func (o *pdOracle) GetTimestamp(ctx context.Context) (uint64, error) {
|
||||
ts, err := o.getTimestamp(ctx)
|
||||
if err != nil {
|
||||
return 0, errors.Trace(err)
|
||||
}
|
||||
@ -72,9 +73,9 @@ func (o *pdOracle) GetTimestamp() (uint64, error) {
|
||||
return ts, nil
|
||||
}
|
||||
|
||||
func (o *pdOracle) getTimestamp() (uint64, error) {
|
||||
func (o *pdOracle) getTimestamp(ctx context.Context) (uint64, error) {
|
||||
now := time.Now()
|
||||
physical, logical, err := o.c.GetTS(context.TODO())
|
||||
physical, logical, err := o.c.GetTS(ctx)
|
||||
if err != nil {
|
||||
return 0, errors.Trace(err)
|
||||
}
|
||||
@ -92,12 +93,12 @@ func (o *pdOracle) setLastTS(ts uint64) {
|
||||
}
|
||||
}
|
||||
|
||||
func (o *pdOracle) updateTS(interval time.Duration) {
|
||||
func (o *pdOracle) updateTS(ctx context.Context, interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
ts, err := o.getTimestamp()
|
||||
ts, err := o.getTimestamp(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("updateTS error: %v", err)
|
||||
break
|
||||
|
||||
@ -289,7 +289,7 @@ func (c *RegionCache) loadRegion(bo *Backoffer, key []byte) (*Region, error) {
|
||||
}
|
||||
}
|
||||
|
||||
meta, leader, err := c.pdClient.GetRegion(goctx.TODO(), key)
|
||||
meta, leader, err := c.pdClient.GetRegion(bo.ctx, key)
|
||||
if err != nil {
|
||||
backoffErr = errors.Errorf("loadRegion from PD failed, key: %q, err: %v", key, err)
|
||||
continue
|
||||
@ -323,7 +323,7 @@ func (c *RegionCache) loadRegionByID(bo *Backoffer, regionID uint64) (*Region, e
|
||||
}
|
||||
}
|
||||
|
||||
meta, leader, err := c.pdClient.GetRegionByID(goctx.TODO(), regionID)
|
||||
meta, leader, err := c.pdClient.GetRegionByID(bo.ctx, regionID)
|
||||
if err != nil {
|
||||
backoffErr = errors.Errorf("loadRegion from PD failed, regionID: %v, err: %v", regionID, err)
|
||||
continue
|
||||
@ -383,7 +383,7 @@ func (c *RegionCache) ClearStoreByID(id uint64) {
|
||||
|
||||
func (c *RegionCache) loadStoreAddr(bo *Backoffer, id uint64) (string, error) {
|
||||
for {
|
||||
store, err := c.pdClient.GetStore(goctx.TODO(), id)
|
||||
store, err := c.pdClient.GetStore(bo.ctx, id)
|
||||
if err != nil {
|
||||
err = errors.Errorf("loadStore from PD failed, id: %d, err: %v", id, err)
|
||||
if err = bo.Backoff(boPDRPC, err); err != nil {
|
||||
|
||||
@ -23,9 +23,9 @@ import (
|
||||
"github.com/pingcap/kvproto/pkg/errorpb"
|
||||
"github.com/pingcap/kvproto/pkg/kvrpcpb"
|
||||
"github.com/pingcap/kvproto/pkg/metapb"
|
||||
"github.com/pingcap/pd/pd-client"
|
||||
pd "github.com/pingcap/pd/pd-client"
|
||||
"github.com/pingcap/tidb"
|
||||
"github.com/pingcap/tidb/store/tikv/mock-tikv"
|
||||
mocktikv "github.com/pingcap/tidb/store/tikv/mock-tikv"
|
||||
"github.com/pingcap/tidb/store/tikv/oracle"
|
||||
goctx "golang.org/x/net/context"
|
||||
)
|
||||
@ -195,7 +195,7 @@ func (o *mockOracle) addOffset(d time.Duration) {
|
||||
o.offset += d
|
||||
}
|
||||
|
||||
func (o *mockOracle) GetTimestamp() (uint64, error) {
|
||||
func (o *mockOracle) GetTimestamp(goctx.Context) (uint64, error) {
|
||||
o.Lock()
|
||||
defer o.Unlock()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user