mocktikv: split rpcHandler to kvHandler and coprHandler (#22857)

Author: disksing
Date: 2021-03-26 19:21:23 +08:00 (committed by GitHub)
parent 2df2ca06d2
commit 7a35af8a4c
5 changed files with 272 additions and 243 deletions
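The rename is mechanical, but the shape of the split is worth spelling out: the old `*rpcHandler` carried both the per-request state (store ID, region key range, isolation level) and every KV and coprocessor method. After this commit the state lives in a `Session` struct, and two thin handler types, `kvHandler` and `coprHandler`, embed a `*Session` and own the methods. The sketch below is an illustrative reduction of that pattern, not the actual mocktikv code; the names mirror the diff but the bodies are stand-ins.

```go
package main

import "fmt"

// session holds the per-request state that used to live on rpcHandler.
type session struct {
	storeID  uint64
	startKey []byte
	endKey   []byte
}

// kvHandler gets session's fields and methods via the embedded pointer.
type kvHandler struct {
	*session
}

// coprHandler shares the same session but owns the coprocessor-side methods.
type coprHandler struct {
	*session
}

func (h kvHandler) handleGet(key []byte) string {
	return fmt.Sprintf("store %d: get %q in [%q, %q)", h.storeID, key, h.startKey, h.endKey)
}

func (h coprHandler) handleDAG(tp string) string {
	return fmt.Sprintf("store %d: coprocessor %s request", h.storeID, tp)
}

func main() {
	// One session per incoming request, wrapped by whichever handler applies.
	s := &session{storeID: 1, startKey: []byte("a"), endKey: []byte("z")}
	fmt.Println(kvHandler{s}.handleGet([]byte("k")))
	fmt.Println(coprHandler{s}.handleDAG("DAG"))
}
```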

View File

@ -33,7 +33,7 @@ import (
"github.com/pingcap/tipb/go-tipb"
)
func (h *rpcHandler) handleCopAnalyzeRequest(req *coprocessor.Request) *coprocessor.Response {
func (h coprHandler) handleCopAnalyzeRequest(req *coprocessor.Request) *coprocessor.Response {
resp := &coprocessor.Response{}
if len(req.Ranges) == 0 {
return resp
@ -62,7 +62,7 @@ func (h *rpcHandler) handleCopAnalyzeRequest(req *coprocessor.Request) *coproces
return resp
}
func (h *rpcHandler) handleAnalyzeIndexReq(req *coprocessor.Request, analyzeReq *tipb.AnalyzeReq) (*coprocessor.Response, error) {
func (h coprHandler) handleAnalyzeIndexReq(req *coprocessor.Request, analyzeReq *tipb.AnalyzeReq) (*coprocessor.Response, error) {
ranges, err := h.extractKVRanges(req.Ranges, false)
if err != nil {
return nil, errors.Trace(err)
@ -125,7 +125,7 @@ type analyzeColumnsExec struct {
fields []*ast.ResultField
}
func (h *rpcHandler) handleAnalyzeColumnsReq(req *coprocessor.Request, analyzeReq *tipb.AnalyzeReq) (_ *coprocessor.Response, err error) {
func (h coprHandler) handleAnalyzeColumnsReq(req *coprocessor.Request, analyzeReq *tipb.AnalyzeReq) (_ *coprocessor.Response, err error) {
sc := flagsToStatementContext(analyzeReq.Flags)
sc.TimeZone, err = constructTimeZone("", int(analyzeReq.TimeZoneOffset))
if err != nil {

View File

@ -20,7 +20,7 @@ import (
"github.com/pingcap/tipb/go-tipb"
)
func (h *rpcHandler) handleCopChecksumRequest(req *coprocessor.Request) *coprocessor.Response {
func (h coprHandler) handleCopChecksumRequest(req *coprocessor.Request) *coprocessor.Response {
resp := &tipb.ChecksumResponse{
Checksum: 1,
TotalKvs: 1,

View File

@ -54,7 +54,7 @@ type dagContext struct {
evalCtx *evalContext
}
func (h *rpcHandler) handleCopDAGRequest(req *coprocessor.Request) *coprocessor.Response {
func (h coprHandler) handleCopDAGRequest(req *coprocessor.Request) *coprocessor.Response {
resp := &coprocessor.Response{}
dagCtx, e, dagReq, err := h.buildDAGExecutor(req)
if err != nil {
@ -88,7 +88,7 @@ func (h *rpcHandler) handleCopDAGRequest(req *coprocessor.Request) *coprocessor.
return buildResp(selResp, execDetails, err)
}
func (h *rpcHandler) buildDAGExecutor(req *coprocessor.Request) (*dagContext, executor, *tipb.DAGRequest, error) {
func (h coprHandler) buildDAGExecutor(req *coprocessor.Request) (*dagContext, executor, *tipb.DAGRequest, error) {
if len(req.Ranges) == 0 {
return nil, nil, nil, errors.New("request range is null")
}
@ -133,7 +133,7 @@ func constructTimeZone(name string, offset int) (*time.Location, error) {
return timeutil.ConstructTimeZone(name, offset)
}
func (h *rpcHandler) handleCopStream(ctx context.Context, req *coprocessor.Request) (tikvpb.Tikv_CoprocessorStreamClient, error) {
func (h coprHandler) handleCopStream(ctx context.Context, req *coprocessor.Request) (tikvpb.Tikv_CoprocessorStreamClient, error) {
dagCtx, e, dagReq, err := h.buildDAGExecutor(req)
if err != nil {
return nil, errors.Trace(err)
@ -147,7 +147,7 @@ func (h *rpcHandler) handleCopStream(ctx context.Context, req *coprocessor.Reque
}, nil
}
func (h *rpcHandler) buildExec(ctx *dagContext, curr *tipb.Executor) (executor, *tipb.Executor, error) {
func (h coprHandler) buildExec(ctx *dagContext, curr *tipb.Executor) (executor, *tipb.Executor, error) {
var currExec executor
var err error
var childExec *tipb.Executor
@ -179,7 +179,7 @@ func (h *rpcHandler) buildExec(ctx *dagContext, curr *tipb.Executor) (executor,
return currExec, childExec, errors.Trace(err)
}
func (h *rpcHandler) buildDAGForTiFlash(ctx *dagContext, farther *tipb.Executor) (executor, error) {
func (h coprHandler) buildDAGForTiFlash(ctx *dagContext, farther *tipb.Executor) (executor, error) {
curr, child, err := h.buildExec(ctx, farther)
if err != nil {
return nil, errors.Trace(err)
@ -194,7 +194,7 @@ func (h *rpcHandler) buildDAGForTiFlash(ctx *dagContext, farther *tipb.Executor)
return curr, nil
}
func (h *rpcHandler) buildDAG(ctx *dagContext, executors []*tipb.Executor) (executor, error) {
func (h coprHandler) buildDAG(ctx *dagContext, executors []*tipb.Executor) (executor, error) {
var src executor
for i := 0; i < len(executors); i++ {
curr, _, err := h.buildExec(ctx, executors[i])
@ -207,7 +207,7 @@ func (h *rpcHandler) buildDAG(ctx *dagContext, executors []*tipb.Executor) (exec
return src, nil
}
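buildDAG walks the executor list front to back and wires each new executor on top of the one built before it, so the last executor is the root that the response is drained from. The following is a minimal sketch of that chaining; the executor interface here is a stand-in, not mocktikv's real one.

```go
package main

import "fmt"

// executor is a stand-in for mocktikv's executor interface.
type executor interface {
	SetSrc(executor)
	Name() string
}

type namedExec struct {
	name string
	src  executor
}

func (e *namedExec) SetSrc(src executor) { e.src = src }

func (e *namedExec) Name() string {
	if e.src == nil {
		return e.name
	}
	return e.name + " <- " + e.src.Name()
}

// buildDAG mirrors the loop in the diff: each executor's source is the previous one.
func buildDAG(names []string) executor {
	var src executor
	for _, n := range names {
		curr := &namedExec{name: n}
		curr.SetSrc(src)
		src = curr
	}
	return src
}

func main() {
	root := buildDAG([]string{"TableScan", "Selection", "StreamAgg"})
	fmt.Println(root.Name()) // StreamAgg <- Selection <- TableScan
}
```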
func (h *rpcHandler) buildTableScan(ctx *dagContext, executor *tipb.Executor) (*tableScanExec, error) {
func (h coprHandler) buildTableScan(ctx *dagContext, executor *tipb.Executor) (*tableScanExec, error) {
columns := executor.TblScan.Columns
ctx.evalCtx.setColumnInfo(columns)
ranges, err := h.extractKVRanges(ctx.keyRanges, executor.TblScan.Desc)
@ -258,7 +258,7 @@ func (h *rpcHandler) buildTableScan(ctx *dagContext, executor *tipb.Executor) (*
return e, nil
}
func (h *rpcHandler) buildIndexScan(ctx *dagContext, executor *tipb.Executor) (*indexScanExec, error) {
func (h coprHandler) buildIndexScan(ctx *dagContext, executor *tipb.Executor) (*indexScanExec, error) {
var err error
columns := executor.IdxScan.Columns
ctx.evalCtx.setColumnInfo(columns)
@ -311,7 +311,7 @@ func (h *rpcHandler) buildIndexScan(ctx *dagContext, executor *tipb.Executor) (*
return e, nil
}
func (h *rpcHandler) buildSelection(ctx *dagContext, executor *tipb.Executor) (*selectionExec, error) {
func (h coprHandler) buildSelection(ctx *dagContext, executor *tipb.Executor) (*selectionExec, error) {
var err error
var relatedColOffsets []int
pbConds := executor.Selection.Conditions
@ -335,7 +335,7 @@ func (h *rpcHandler) buildSelection(ctx *dagContext, executor *tipb.Executor) (*
}, nil
}
func (h *rpcHandler) getAggInfo(ctx *dagContext, executor *tipb.Executor) ([]aggregation.Aggregation, []expression.Expression, []int, error) {
func (h coprHandler) getAggInfo(ctx *dagContext, executor *tipb.Executor) ([]aggregation.Aggregation, []expression.Expression, []int, error) {
length := len(executor.Aggregation.AggFunc)
aggs := make([]aggregation.Aggregation, 0, length)
var err error
@ -366,7 +366,7 @@ func (h *rpcHandler) getAggInfo(ctx *dagContext, executor *tipb.Executor) ([]agg
return aggs, groupBys, relatedColOffsets, nil
}
func (h *rpcHandler) buildHashAgg(ctx *dagContext, executor *tipb.Executor) (*hashAggExec, error) {
func (h coprHandler) buildHashAgg(ctx *dagContext, executor *tipb.Executor) (*hashAggExec, error) {
aggs, groupBys, relatedColOffsets, err := h.getAggInfo(ctx, executor)
if err != nil {
return nil, errors.Trace(err)
@ -384,7 +384,7 @@ func (h *rpcHandler) buildHashAgg(ctx *dagContext, executor *tipb.Executor) (*ha
}, nil
}
func (h *rpcHandler) buildStreamAgg(ctx *dagContext, executor *tipb.Executor) (*streamAggExec, error) {
func (h coprHandler) buildStreamAgg(ctx *dagContext, executor *tipb.Executor) (*streamAggExec, error) {
aggs, groupBys, relatedColOffsets, err := h.getAggInfo(ctx, executor)
if err != nil {
return nil, errors.Trace(err)
@ -406,7 +406,7 @@ func (h *rpcHandler) buildStreamAgg(ctx *dagContext, executor *tipb.Executor) (*
}, nil
}
func (h *rpcHandler) buildTopN(ctx *dagContext, executor *tipb.Executor) (*topNExec, error) {
func (h coprHandler) buildTopN(ctx *dagContext, executor *tipb.Executor) (*topNExec, error) {
topN := executor.TopN
var err error
var relatedColOffsets []int
@ -664,7 +664,7 @@ func (mock *mockCopStreamClient) readBlockFromExecutor() (tipb.Chunk, bool, *cop
return chunk, finish, &ran, mock.exec.Counts(), warnings, nil
}
func (h *rpcHandler) initSelectResponse(err error, warnings []stmtctx.SQLWarn, counts []int64) *tipb.SelectResponse {
func (h coprHandler) initSelectResponse(err error, warnings []stmtctx.SQLWarn, counts []int64) *tipb.SelectResponse {
selResp := &tipb.SelectResponse{
Error: toPBError(err),
OutputCounts: counts,
@ -675,7 +675,7 @@ func (h *rpcHandler) initSelectResponse(err error, warnings []stmtctx.SQLWarn, c
return selResp
}
func (h *rpcHandler) fillUpData4SelectResponse(selResp *tipb.SelectResponse, dagReq *tipb.DAGRequest, dagCtx *dagContext, rows [][][]byte) error {
func (h coprHandler) fillUpData4SelectResponse(selResp *tipb.SelectResponse, dagReq *tipb.DAGRequest, dagCtx *dagContext, rows [][][]byte) error {
switch dagReq.EncodeType {
case tipb.EncodeType_TypeDefault:
h.encodeDefault(selResp, rows, dagReq.OutputOffsets)
@ -690,7 +690,7 @@ func (h *rpcHandler) fillUpData4SelectResponse(selResp *tipb.SelectResponse, dag
return nil
}
func (h *rpcHandler) constructRespSchema(dagCtx *dagContext) []*types.FieldType {
func (h coprHandler) constructRespSchema(dagCtx *dagContext) []*types.FieldType {
var root *tipb.Executor
if len(dagCtx.dagReq.Executors) == 0 {
root = dagCtx.dagReq.RootExecutor
@ -717,7 +717,7 @@ func (h *rpcHandler) constructRespSchema(dagCtx *dagContext) []*types.FieldType
return schema
}
func (h *rpcHandler) encodeDefault(selResp *tipb.SelectResponse, rows [][][]byte, colOrdinal []uint32) {
func (h coprHandler) encodeDefault(selResp *tipb.SelectResponse, rows [][][]byte, colOrdinal []uint32) {
var chunks []tipb.Chunk
for i := range rows {
requestedRow := dummySlice
@ -730,7 +730,7 @@ func (h *rpcHandler) encodeDefault(selResp *tipb.SelectResponse, rows [][][]byte
selResp.EncodeType = tipb.EncodeType_TypeDefault
}
func (h *rpcHandler) encodeChunk(selResp *tipb.SelectResponse, rows [][][]byte, colTypes []*types.FieldType, colOrdinal []uint32, loc *time.Location) error {
func (h coprHandler) encodeChunk(selResp *tipb.SelectResponse, rows [][][]byte, colTypes []*types.FieldType, colOrdinal []uint32, loc *time.Location) error {
var chunks []tipb.Chunk
respColTypes := make([]*types.FieldType, 0, len(colOrdinal))
for _, ordinal := range colOrdinal {
@ -826,7 +826,7 @@ func toPBError(err error) *tipb.Error {
}
// extractKVRanges extracts kv.KeyRanges slice from a SelectRequest.
func (h *rpcHandler) extractKVRanges(keyRanges []*coprocessor.KeyRange, descScan bool) (kvRanges []kv.KeyRange, err error) {
func (h coprHandler) extractKVRanges(keyRanges []*coprocessor.KeyRange, descScan bool) (kvRanges []kv.KeyRange, err error) {
for _, kran := range keyRanges {
if bytes.Compare(kran.GetStart(), kran.GetEnd()) >= 0 {
err = errors.Errorf("invalid range, start should be smaller than end: %v %v", kran.GetStart(), kran.GetEnd())
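extractKVRanges rejects any range whose start is not strictly below its end and, as far as the surrounding code suggests, clips the ranges to the region's raw key span and reverses the result for descending scans. A rough, self-contained sketch of that behavior (illustrative, not the exact mocktikv implementation):

```go
package main

import (
	"bytes"
	"fmt"
)

type keyRange struct{ start, end []byte }

// extractKVRanges is an illustrative reduction: validate each range, clip it
// to the region's raw boundaries, and reverse the slice for descending scans.
func extractKVRanges(regionStart, regionEnd []byte, ranges []keyRange, desc bool) ([]keyRange, error) {
	var out []keyRange
	for _, r := range ranges {
		if bytes.Compare(r.start, r.end) >= 0 {
			return nil, fmt.Errorf("invalid range: start %q >= end %q", r.start, r.end)
		}
		start, end := r.start, r.end
		if bytes.Compare(start, regionStart) < 0 {
			start = regionStart
		}
		if len(regionEnd) > 0 && bytes.Compare(end, regionEnd) > 0 {
			end = regionEnd
		}
		if bytes.Compare(start, end) < 0 {
			out = append(out, keyRange{start, end})
		}
	}
	if desc {
		for i, j := 0, len(out)-1; i < j; i, j = i+1, j-1 {
			out[i], out[j] = out[j], out[i]
		}
	}
	return out, nil
}

func main() {
	got, _ := extractKVRanges([]byte("b"), []byte("y"),
		[]keyRange{{[]byte("a"), []byte("d")}, {[]byte("m"), []byte("z")}}, true)
	for _, r := range got {
		fmt.Printf("[%q, %q)\n", r.start, r.end)
	}
}
```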

View File

@ -22,7 +22,6 @@ import (
"sync"
"time"
"github.com/golang/protobuf/proto"
"github.com/opentracing/opentracing-go"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
@ -32,7 +31,6 @@ import (
"github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/kvproto/pkg/metapb"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/ddl/placement"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/store/tikv/tikvrpc"
"github.com/pingcap/tipb/go-tipb"
@ -141,132 +139,13 @@ func convertToPbPairs(pairs []Pair) []*kvrpcpb.KvPair {
return kvPairs
}
// rpcHandler mocks tikv's side handler behavior. In general, you may assume
// kvHandler mocks tikv's side handler behavior. In general, you may assume
// TiKV just translates the logic from Go to Rust.
type rpcHandler struct {
cluster *Cluster
mvccStore MVCCStore
// storeID stores id for current request
storeID uint64
// startKey is used for handling normal request.
startKey []byte
endKey []byte
// rawStartKey is used for handling coprocessor request.
rawStartKey []byte
rawEndKey []byte
// isolationLevel is used for current request.
isolationLevel kvrpcpb.IsolationLevel
resolvedLocks []uint64
type kvHandler struct {
*Session
}
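Note that the receivers change from `*rpcHandler` to plain `kvHandler` / `coprHandler` values. That works because the handler holds only an embedded `*Session`: copying the handler copies the pointer, so the region boundaries that checkRequestContext writes into the Session are exactly what the handler methods later read, and any write made through the handler lands on the same Session. A minimal illustration of that semantics, with stand-in names:

```go
package main

import "fmt"

type session struct{ startKey, endKey []byte }

type handler struct{ *session }

// setRange has a value receiver, but mutates the shared *session.
func (h handler) setRange(start, end []byte) {
	h.startKey, h.endKey = start, end
}

func main() {
	s := &session{}
	handler{s}.setRange([]byte("a"), []byte("m"))
	fmt.Printf("%q %q\n", s.startKey, s.endKey) // "a" "m": the write survives the value copy
}
```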
func isTiFlashStore(store *metapb.Store) bool {
for _, l := range store.GetLabels() {
if l.GetKey() == placement.EngineLabelKey && l.GetValue() == placement.EngineLabelTiFlash {
return true
}
}
return false
}
func (h *rpcHandler) checkRequestContext(ctx *kvrpcpb.Context) *errorpb.Error {
ctxPeer := ctx.GetPeer()
if ctxPeer != nil && ctxPeer.GetStoreId() != h.storeID {
return &errorpb.Error{
Message: *proto.String("store not match"),
StoreNotMatch: &errorpb.StoreNotMatch{},
}
}
region, leaderID := h.cluster.GetRegion(ctx.GetRegionId())
// No region found.
if region == nil {
return &errorpb.Error{
Message: *proto.String("region not found"),
RegionNotFound: &errorpb.RegionNotFound{
RegionId: *proto.Uint64(ctx.GetRegionId()),
},
}
}
var storePeer, leaderPeer *metapb.Peer
for _, p := range region.Peers {
if p.GetStoreId() == h.storeID {
storePeer = p
}
if p.GetId() == leaderID {
leaderPeer = p
}
}
// The Store does not contain a Peer of the Region.
if storePeer == nil {
return &errorpb.Error{
Message: *proto.String("region not found"),
RegionNotFound: &errorpb.RegionNotFound{
RegionId: *proto.Uint64(ctx.GetRegionId()),
},
}
}
// No leader.
if leaderPeer == nil {
return &errorpb.Error{
Message: *proto.String("no leader"),
NotLeader: &errorpb.NotLeader{
RegionId: *proto.Uint64(ctx.GetRegionId()),
},
}
}
// The Peer on the Store is not the leader. If it's a TiFlash store, we pass this check.
if storePeer.GetId() != leaderPeer.GetId() && !isTiFlashStore(h.cluster.GetStore(storePeer.GetStoreId())) {
return &errorpb.Error{
Message: *proto.String("not leader"),
NotLeader: &errorpb.NotLeader{
RegionId: *proto.Uint64(ctx.GetRegionId()),
Leader: leaderPeer,
},
}
}
// Region epoch does not match.
if !proto.Equal(region.GetRegionEpoch(), ctx.GetRegionEpoch()) {
nextRegion, _ := h.cluster.GetRegionByKey(region.GetEndKey())
currentRegions := []*metapb.Region{region}
if nextRegion != nil {
currentRegions = append(currentRegions, nextRegion)
}
return &errorpb.Error{
Message: *proto.String("epoch not match"),
EpochNotMatch: &errorpb.EpochNotMatch{
CurrentRegions: currentRegions,
},
}
}
h.startKey, h.endKey = region.StartKey, region.EndKey
h.isolationLevel = ctx.IsolationLevel
h.resolvedLocks = ctx.ResolvedLocks
return nil
}
func (h *rpcHandler) checkRequestSize(size int) *errorpb.Error {
// TiKV has a limitation on raft log size.
// mocktikv has no raft inside, so we check the request's size instead.
if size >= requestMaxSize {
return &errorpb.Error{
RaftEntryTooLarge: &errorpb.RaftEntryTooLarge{},
}
}
return nil
}
func (h *rpcHandler) checkRequest(ctx *kvrpcpb.Context, size int) *errorpb.Error {
if err := h.checkRequestContext(ctx); err != nil {
return err
}
return h.checkRequestSize(size)
}
func (h *rpcHandler) checkKeyInRegion(key []byte) bool {
return regionContains(h.startKey, h.endKey, NewMvccKey(key))
}
func (h *rpcHandler) handleKvGet(req *kvrpcpb.GetRequest) *kvrpcpb.GetResponse {
func (h kvHandler) handleKvGet(req *kvrpcpb.GetRequest) *kvrpcpb.GetResponse {
if !h.checkKeyInRegion(req.Key) {
panic("KvGet: key not in region")
}
@ -282,7 +161,7 @@ func (h *rpcHandler) handleKvGet(req *kvrpcpb.GetRequest) *kvrpcpb.GetResponse {
}
}
func (h *rpcHandler) handleKvScan(req *kvrpcpb.ScanRequest) *kvrpcpb.ScanResponse {
func (h kvHandler) handleKvScan(req *kvrpcpb.ScanRequest) *kvrpcpb.ScanResponse {
endKey := MvccKey(h.endKey).Raw()
var pairs []Pair
if !req.Reverse {
@ -314,7 +193,7 @@ func (h *rpcHandler) handleKvScan(req *kvrpcpb.ScanRequest) *kvrpcpb.ScanRespons
}
}
func (h *rpcHandler) handleKvPrewrite(req *kvrpcpb.PrewriteRequest) *kvrpcpb.PrewriteResponse {
func (h kvHandler) handleKvPrewrite(req *kvrpcpb.PrewriteRequest) *kvrpcpb.PrewriteResponse {
regionID := req.Context.RegionId
h.cluster.handleDelay(req.StartVersion, regionID)
@ -329,7 +208,7 @@ func (h *rpcHandler) handleKvPrewrite(req *kvrpcpb.PrewriteRequest) *kvrpcpb.Pre
}
}
func (h *rpcHandler) handleKvPessimisticLock(req *kvrpcpb.PessimisticLockRequest) *kvrpcpb.PessimisticLockResponse {
func (h kvHandler) handleKvPessimisticLock(req *kvrpcpb.PessimisticLockRequest) *kvrpcpb.PessimisticLockResponse {
for _, m := range req.Mutations {
if !h.checkKeyInRegion(m.Key) {
panic("KvPessimisticLock: key not in region")
@ -350,7 +229,7 @@ func simulateServerSideWaitLock(errs []error) {
}
}
func (h *rpcHandler) handleKvPessimisticRollback(req *kvrpcpb.PessimisticRollbackRequest) *kvrpcpb.PessimisticRollbackResponse {
func (h kvHandler) handleKvPessimisticRollback(req *kvrpcpb.PessimisticRollbackRequest) *kvrpcpb.PessimisticRollbackResponse {
for _, key := range req.Keys {
if !h.checkKeyInRegion(key) {
panic("KvPessimisticRollback: key not in region")
@ -362,7 +241,7 @@ func (h *rpcHandler) handleKvPessimisticRollback(req *kvrpcpb.PessimisticRollbac
}
}
func (h *rpcHandler) handleKvCommit(req *kvrpcpb.CommitRequest) *kvrpcpb.CommitResponse {
func (h kvHandler) handleKvCommit(req *kvrpcpb.CommitRequest) *kvrpcpb.CommitResponse {
for _, k := range req.Keys {
if !h.checkKeyInRegion(k) {
panic("KvCommit: key not in region")
@ -376,7 +255,7 @@ func (h *rpcHandler) handleKvCommit(req *kvrpcpb.CommitRequest) *kvrpcpb.CommitR
return &resp
}
func (h *rpcHandler) handleKvCleanup(req *kvrpcpb.CleanupRequest) *kvrpcpb.CleanupResponse {
func (h kvHandler) handleKvCleanup(req *kvrpcpb.CleanupRequest) *kvrpcpb.CleanupResponse {
if !h.checkKeyInRegion(req.Key) {
panic("KvCleanup: key not in region")
}
@ -392,7 +271,7 @@ func (h *rpcHandler) handleKvCleanup(req *kvrpcpb.CleanupRequest) *kvrpcpb.Clean
return &resp
}
func (h *rpcHandler) handleKvCheckTxnStatus(req *kvrpcpb.CheckTxnStatusRequest) *kvrpcpb.CheckTxnStatusResponse {
func (h kvHandler) handleKvCheckTxnStatus(req *kvrpcpb.CheckTxnStatusRequest) *kvrpcpb.CheckTxnStatusResponse {
if !h.checkKeyInRegion(req.PrimaryKey) {
panic("KvCheckTxnStatus: key not in region")
}
@ -406,7 +285,7 @@ func (h *rpcHandler) handleKvCheckTxnStatus(req *kvrpcpb.CheckTxnStatusRequest)
return &resp
}
func (h *rpcHandler) handleTxnHeartBeat(req *kvrpcpb.TxnHeartBeatRequest) *kvrpcpb.TxnHeartBeatResponse {
func (h kvHandler) handleTxnHeartBeat(req *kvrpcpb.TxnHeartBeatRequest) *kvrpcpb.TxnHeartBeatResponse {
if !h.checkKeyInRegion(req.PrimaryLock) {
panic("KvTxnHeartBeat: key not in region")
}
@ -419,7 +298,7 @@ func (h *rpcHandler) handleTxnHeartBeat(req *kvrpcpb.TxnHeartBeatRequest) *kvrpc
return &resp
}
func (h *rpcHandler) handleKvBatchGet(req *kvrpcpb.BatchGetRequest) *kvrpcpb.BatchGetResponse {
func (h kvHandler) handleKvBatchGet(req *kvrpcpb.BatchGetRequest) *kvrpcpb.BatchGetResponse {
for _, k := range req.Keys {
if !h.checkKeyInRegion(k) {
panic("KvBatchGet: key not in region")
@ -431,7 +310,7 @@ func (h *rpcHandler) handleKvBatchGet(req *kvrpcpb.BatchGetRequest) *kvrpcpb.Bat
}
}
func (h *rpcHandler) handleMvccGetByKey(req *kvrpcpb.MvccGetByKeyRequest) *kvrpcpb.MvccGetByKeyResponse {
func (h kvHandler) handleMvccGetByKey(req *kvrpcpb.MvccGetByKeyRequest) *kvrpcpb.MvccGetByKeyResponse {
debugger, ok := h.mvccStore.(MVCCDebugger)
if !ok {
return &kvrpcpb.MvccGetByKeyResponse{
@ -447,7 +326,7 @@ func (h *rpcHandler) handleMvccGetByKey(req *kvrpcpb.MvccGetByKeyRequest) *kvrpc
return &resp
}
func (h *rpcHandler) handleMvccGetByStartTS(req *kvrpcpb.MvccGetByStartTsRequest) *kvrpcpb.MvccGetByStartTsResponse {
func (h kvHandler) handleMvccGetByStartTS(req *kvrpcpb.MvccGetByStartTsRequest) *kvrpcpb.MvccGetByStartTsResponse {
debugger, ok := h.mvccStore.(MVCCDebugger)
if !ok {
return &kvrpcpb.MvccGetByStartTsResponse{
@ -459,7 +338,7 @@ func (h *rpcHandler) handleMvccGetByStartTS(req *kvrpcpb.MvccGetByStartTsRequest
return &resp
}
func (h *rpcHandler) handleKvBatchRollback(req *kvrpcpb.BatchRollbackRequest) *kvrpcpb.BatchRollbackResponse {
func (h kvHandler) handleKvBatchRollback(req *kvrpcpb.BatchRollbackRequest) *kvrpcpb.BatchRollbackResponse {
err := h.mvccStore.Rollback(req.Keys, req.StartVersion)
if err != nil {
return &kvrpcpb.BatchRollbackResponse{
@ -469,7 +348,7 @@ func (h *rpcHandler) handleKvBatchRollback(req *kvrpcpb.BatchRollbackRequest) *k
return &kvrpcpb.BatchRollbackResponse{}
}
func (h *rpcHandler) handleKvScanLock(req *kvrpcpb.ScanLockRequest) *kvrpcpb.ScanLockResponse {
func (h kvHandler) handleKvScanLock(req *kvrpcpb.ScanLockRequest) *kvrpcpb.ScanLockResponse {
startKey := MvccKey(h.startKey).Raw()
endKey := MvccKey(h.endKey).Raw()
locks, err := h.mvccStore.ScanLock(startKey, endKey, req.GetMaxVersion())
@ -483,7 +362,7 @@ func (h *rpcHandler) handleKvScanLock(req *kvrpcpb.ScanLockRequest) *kvrpcpb.Sca
}
}
func (h *rpcHandler) handleKvResolveLock(req *kvrpcpb.ResolveLockRequest) *kvrpcpb.ResolveLockResponse {
func (h kvHandler) handleKvResolveLock(req *kvrpcpb.ResolveLockRequest) *kvrpcpb.ResolveLockResponse {
startKey := MvccKey(h.startKey).Raw()
endKey := MvccKey(h.endKey).Raw()
err := h.mvccStore.ResolveLock(startKey, endKey, req.GetStartVersion(), req.GetCommitVersion())
@ -495,7 +374,7 @@ func (h *rpcHandler) handleKvResolveLock(req *kvrpcpb.ResolveLockRequest) *kvrpc
return &kvrpcpb.ResolveLockResponse{}
}
func (h *rpcHandler) handleKvGC(req *kvrpcpb.GCRequest) *kvrpcpb.GCResponse {
func (h kvHandler) handleKvGC(req *kvrpcpb.GCRequest) *kvrpcpb.GCResponse {
startKey := MvccKey(h.startKey).Raw()
endKey := MvccKey(h.endKey).Raw()
err := h.mvccStore.GC(startKey, endKey, req.GetSafePoint())
@ -507,7 +386,7 @@ func (h *rpcHandler) handleKvGC(req *kvrpcpb.GCRequest) *kvrpcpb.GCResponse {
return &kvrpcpb.GCResponse{}
}
func (h *rpcHandler) handleKvDeleteRange(req *kvrpcpb.DeleteRangeRequest) *kvrpcpb.DeleteRangeResponse {
func (h kvHandler) handleKvDeleteRange(req *kvrpcpb.DeleteRangeRequest) *kvrpcpb.DeleteRangeResponse {
if !h.checkKeyInRegion(req.StartKey) {
panic("KvDeleteRange: key not in region")
}
@ -519,7 +398,7 @@ func (h *rpcHandler) handleKvDeleteRange(req *kvrpcpb.DeleteRangeRequest) *kvrpc
return &resp
}
func (h *rpcHandler) handleKvRawGet(req *kvrpcpb.RawGetRequest) *kvrpcpb.RawGetResponse {
func (h kvHandler) handleKvRawGet(req *kvrpcpb.RawGetRequest) *kvrpcpb.RawGetResponse {
rawKV, ok := h.mvccStore.(RawKV)
if !ok {
return &kvrpcpb.RawGetResponse{
@ -531,7 +410,7 @@ func (h *rpcHandler) handleKvRawGet(req *kvrpcpb.RawGetRequest) *kvrpcpb.RawGetR
}
}
func (h *rpcHandler) handleKvRawBatchGet(req *kvrpcpb.RawBatchGetRequest) *kvrpcpb.RawBatchGetResponse {
func (h kvHandler) handleKvRawBatchGet(req *kvrpcpb.RawBatchGetRequest) *kvrpcpb.RawBatchGetResponse {
rawKV, ok := h.mvccStore.(RawKV)
if !ok {
// TODO: should we add an error?
@ -554,7 +433,7 @@ func (h *rpcHandler) handleKvRawBatchGet(req *kvrpcpb.RawBatchGetRequest) *kvrpc
}
}
func (h *rpcHandler) handleKvRawPut(req *kvrpcpb.RawPutRequest) *kvrpcpb.RawPutResponse {
func (h kvHandler) handleKvRawPut(req *kvrpcpb.RawPutRequest) *kvrpcpb.RawPutResponse {
rawKV, ok := h.mvccStore.(RawKV)
if !ok {
return &kvrpcpb.RawPutResponse{
@ -565,7 +444,7 @@ func (h *rpcHandler) handleKvRawPut(req *kvrpcpb.RawPutRequest) *kvrpcpb.RawPutR
return &kvrpcpb.RawPutResponse{}
}
func (h *rpcHandler) handleKvRawBatchPut(req *kvrpcpb.RawBatchPutRequest) *kvrpcpb.RawBatchPutResponse {
func (h kvHandler) handleKvRawBatchPut(req *kvrpcpb.RawBatchPutRequest) *kvrpcpb.RawBatchPutResponse {
rawKV, ok := h.mvccStore.(RawKV)
if !ok {
return &kvrpcpb.RawBatchPutResponse{
@ -582,7 +461,7 @@ func (h *rpcHandler) handleKvRawBatchPut(req *kvrpcpb.RawBatchPutRequest) *kvrpc
return &kvrpcpb.RawBatchPutResponse{}
}
func (h *rpcHandler) handleKvRawDelete(req *kvrpcpb.RawDeleteRequest) *kvrpcpb.RawDeleteResponse {
func (h kvHandler) handleKvRawDelete(req *kvrpcpb.RawDeleteRequest) *kvrpcpb.RawDeleteResponse {
rawKV, ok := h.mvccStore.(RawKV)
if !ok {
return &kvrpcpb.RawDeleteResponse{
@ -593,7 +472,7 @@ func (h *rpcHandler) handleKvRawDelete(req *kvrpcpb.RawDeleteRequest) *kvrpcpb.R
return &kvrpcpb.RawDeleteResponse{}
}
func (h *rpcHandler) handleKvRawBatchDelete(req *kvrpcpb.RawBatchDeleteRequest) *kvrpcpb.RawBatchDeleteResponse {
func (h kvHandler) handleKvRawBatchDelete(req *kvrpcpb.RawBatchDeleteRequest) *kvrpcpb.RawBatchDeleteResponse {
rawKV, ok := h.mvccStore.(RawKV)
if !ok {
return &kvrpcpb.RawBatchDeleteResponse{
@ -604,7 +483,7 @@ func (h *rpcHandler) handleKvRawBatchDelete(req *kvrpcpb.RawBatchDeleteRequest)
return &kvrpcpb.RawBatchDeleteResponse{}
}
func (h *rpcHandler) handleKvRawDeleteRange(req *kvrpcpb.RawDeleteRangeRequest) *kvrpcpb.RawDeleteRangeResponse {
func (h kvHandler) handleKvRawDeleteRange(req *kvrpcpb.RawDeleteRangeRequest) *kvrpcpb.RawDeleteRangeResponse {
rawKV, ok := h.mvccStore.(RawKV)
if !ok {
return &kvrpcpb.RawDeleteRangeResponse{
@ -615,7 +494,7 @@ func (h *rpcHandler) handleKvRawDeleteRange(req *kvrpcpb.RawDeleteRangeRequest)
return &kvrpcpb.RawDeleteRangeResponse{}
}
func (h *rpcHandler) handleKvRawScan(req *kvrpcpb.RawScanRequest) *kvrpcpb.RawScanResponse {
func (h kvHandler) handleKvRawScan(req *kvrpcpb.RawScanRequest) *kvrpcpb.RawScanResponse {
rawKV, ok := h.mvccStore.(RawKV)
if !ok {
errStr := "not implemented"
@ -654,7 +533,7 @@ func (h *rpcHandler) handleKvRawScan(req *kvrpcpb.RawScanRequest) *kvrpcpb.RawSc
}
}
func (h *rpcHandler) handleSplitRegion(req *kvrpcpb.SplitRegionRequest) *kvrpcpb.SplitRegionResponse {
func (h kvHandler) handleSplitRegion(req *kvrpcpb.SplitRegionRequest) *kvrpcpb.SplitRegionResponse {
keys := req.GetSplitKeys()
resp := &kvrpcpb.SplitRegionResponse{Regions: make([]*metapb.Region, 0, len(keys)+1)}
for i, key := range keys {
@ -690,7 +569,11 @@ func drainRowsFromExecutor(ctx context.Context, e executor, req *tipb.DAGRequest
}
}
func (h *rpcHandler) handleBatchCopRequest(ctx context.Context, req *coprocessor.BatchRequest) (*mockBatchCopDataClient, error) {
type coprHandler struct {
*Session
}
func (h coprHandler) handleBatchCopRequest(ctx context.Context, req *coprocessor.BatchRequest) (*mockBatchCopDataClient, error) {
client := &mockBatchCopDataClient{}
for _, ri := range req.Regions {
cop := coprocessor.Request{
@ -766,7 +649,7 @@ func (c *RPCClient) getAndCheckStoreByAddr(addr string) (*metapb.Store, error) {
return nil, errors.New("connection refused")
}
func (c *RPCClient) checkArgs(ctx context.Context, addr string) (*rpcHandler, error) {
func (c *RPCClient) checkArgs(ctx context.Context, addr string) (*Session, error) {
if err := checkGoContext(ctx); err != nil {
return nil, err
}
@ -775,13 +658,13 @@ func (c *RPCClient) checkArgs(ctx context.Context, addr string) (*rpcHandler, er
if err != nil {
return nil, err
}
handler := &rpcHandler{
session := &Session{
cluster: c.Cluster,
mvccStore: c.MvccStore,
// set store id for current request
storeID: store.GetId(),
}
return handler, nil
return session, nil
}
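checkArgs now returns a bare `*Session`; SendRequest then wraps it in `kvHandler{session}` or `coprHandler{session}` per command, after running the shared checkRequest / checkRequestContext validation on the session itself. The sketch below shows the shape of that dispatch under stand-in request and response types; it is not the real SendRequest.

```go
package main

import (
	"errors"
	"fmt"
)

type session struct{ storeID uint64 }

// checkRequest is a stand-in for the shared per-request validation.
func (s *session) checkRequest(size int) error {
	if size > 1024 {
		return errors.New("raft entry too large")
	}
	return nil
}

type kvHandler struct{ *session }
type coprHandler struct{ *session }

func (h kvHandler) handleGet(key string) string   { return "value-of-" + key }
func (h coprHandler) handleDAG(req string) string { return "rows-for-" + req }

// sendRequest mirrors the dispatch: validate on the session,
// then hand off to the handler type that owns the command.
func sendRequest(s *session, cmd, payload string) (string, error) {
	if err := s.checkRequest(len(payload)); err != nil {
		return "", err
	}
	switch cmd {
	case "Get":
		return kvHandler{s}.handleGet(payload), nil
	case "Cop":
		return coprHandler{s}.handleDAG(payload), nil
	default:
		return "", fmt.Errorf("unknown command %q", cmd)
	}
}

func main() {
	s := &session{storeID: 1}
	v, _ := sendRequest(s, "Get", "k1")
	fmt.Println(v)
}
```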
// GRPCClientFactory is the GRPC client factory.
@ -828,25 +711,25 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R
return c.redirectRequestToRPCServer(ctx, addr, req, timeout)
}
handler, err := c.checkArgs(ctx, addr)
session, err := c.checkArgs(ctx, addr)
if err != nil {
return nil, err
}
switch req.Type {
case tikvrpc.CmdGet:
r := req.Get()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.GetResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvGet(r)
resp.Resp = kvHandler{session}.handleKvGet(r)
case tikvrpc.CmdScan:
r := req.Scan()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.ScanResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvScan(r)
resp.Resp = kvHandler{session}.handleKvScan(r)
case tikvrpc.CmdPrewrite:
failpoint.Inject("rpcPrewriteResult", func(val failpoint.Value) {
@ -859,25 +742,25 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R
})
r := req.Prewrite()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.PrewriteResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvPrewrite(r)
resp.Resp = kvHandler{session}.handleKvPrewrite(r)
case tikvrpc.CmdPessimisticLock:
r := req.PessimisticLock()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.PessimisticLockResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvPessimisticLock(r)
resp.Resp = kvHandler{session}.handleKvPessimisticLock(r)
case tikvrpc.CmdPessimisticRollback:
r := req.PessimisticRollback()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.PessimisticRollbackResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvPessimisticRollback(r)
resp.Resp = kvHandler{session}.handleKvPessimisticRollback(r)
case tikvrpc.CmdCommit:
failpoint.Inject("rpcCommitResult", func(val failpoint.Value) {
switch val.(string) {
@ -895,11 +778,11 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R
})
r := req.Commit()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.CommitResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvCommit(r)
resp.Resp = kvHandler{session}.handleKvCommit(r)
failpoint.Inject("rpcCommitTimeout", func(val failpoint.Value) {
if val.(bool) {
failpoint.Return(nil, undeterminedErr)
@ -907,122 +790,122 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R
})
case tikvrpc.CmdCleanup:
r := req.Cleanup()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.CleanupResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvCleanup(r)
resp.Resp = kvHandler{session}.handleKvCleanup(r)
case tikvrpc.CmdCheckTxnStatus:
r := req.CheckTxnStatus()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.CheckTxnStatusResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvCheckTxnStatus(r)
resp.Resp = kvHandler{session}.handleKvCheckTxnStatus(r)
case tikvrpc.CmdTxnHeartBeat:
r := req.TxnHeartBeat()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.TxnHeartBeatResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleTxnHeartBeat(r)
resp.Resp = kvHandler{session}.handleTxnHeartBeat(r)
case tikvrpc.CmdBatchGet:
r := req.BatchGet()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.BatchGetResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvBatchGet(r)
resp.Resp = kvHandler{session}.handleKvBatchGet(r)
case tikvrpc.CmdBatchRollback:
r := req.BatchRollback()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.BatchRollbackResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvBatchRollback(r)
resp.Resp = kvHandler{session}.handleKvBatchRollback(r)
case tikvrpc.CmdScanLock:
r := req.ScanLock()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.ScanLockResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvScanLock(r)
resp.Resp = kvHandler{session}.handleKvScanLock(r)
case tikvrpc.CmdResolveLock:
r := req.ResolveLock()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.ResolveLockResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvResolveLock(r)
resp.Resp = kvHandler{session}.handleKvResolveLock(r)
case tikvrpc.CmdGC:
r := req.GC()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.GCResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvGC(r)
resp.Resp = kvHandler{session}.handleKvGC(r)
case tikvrpc.CmdDeleteRange:
r := req.DeleteRange()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.DeleteRangeResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvDeleteRange(r)
resp.Resp = kvHandler{session}.handleKvDeleteRange(r)
case tikvrpc.CmdRawGet:
r := req.RawGet()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.RawGetResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvRawGet(r)
resp.Resp = kvHandler{session}.handleKvRawGet(r)
case tikvrpc.CmdRawBatchGet:
r := req.RawBatchGet()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.RawBatchGetResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvRawBatchGet(r)
resp.Resp = kvHandler{session}.handleKvRawBatchGet(r)
case tikvrpc.CmdRawPut:
r := req.RawPut()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.RawPutResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvRawPut(r)
resp.Resp = kvHandler{session}.handleKvRawPut(r)
case tikvrpc.CmdRawBatchPut:
r := req.RawBatchPut()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.RawBatchPutResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvRawBatchPut(r)
resp.Resp = kvHandler{session}.handleKvRawBatchPut(r)
case tikvrpc.CmdRawDelete:
r := req.RawDelete()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.RawDeleteResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvRawDelete(r)
resp.Resp = kvHandler{session}.handleKvRawDelete(r)
case tikvrpc.CmdRawBatchDelete:
r := req.RawBatchDelete()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.RawBatchDeleteResponse{RegionError: err}
}
resp.Resp = handler.handleKvRawBatchDelete(r)
resp.Resp = kvHandler{session}.handleKvRawBatchDelete(r)
case tikvrpc.CmdRawDeleteRange:
r := req.RawDeleteRange()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.RawDeleteRangeResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvRawDeleteRange(r)
resp.Resp = kvHandler{session}.handleKvRawDeleteRange(r)
case tikvrpc.CmdRawScan:
r := req.RawScan()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.RawScanResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleKvRawScan(r)
resp.Resp = kvHandler{session}.handleKvRawScan(r)
case tikvrpc.CmdUnsafeDestroyRange:
panic("unimplemented")
case tikvrpc.CmdRegisterLockObserver:
@ -1035,20 +918,20 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R
return nil, errors.New("unimplemented")
case tikvrpc.CmdCop:
r := req.Cop()
if err := handler.checkRequestContext(reqCtx); err != nil {
if err := session.checkRequestContext(reqCtx); err != nil {
resp.Resp = &coprocessor.Response{RegionError: err}
return resp, nil
}
handler.rawStartKey = MvccKey(handler.startKey).Raw()
handler.rawEndKey = MvccKey(handler.endKey).Raw()
session.rawStartKey = MvccKey(session.startKey).Raw()
session.rawEndKey = MvccKey(session.endKey).Raw()
var res *coprocessor.Response
switch r.GetTp() {
case kv.ReqTypeDAG:
res = handler.handleCopDAGRequest(r)
res = coprHandler{session}.handleCopDAGRequest(r)
case kv.ReqTypeAnalyze:
res = handler.handleCopAnalyzeRequest(r)
res = coprHandler{session}.handleCopAnalyzeRequest(r)
case kv.ReqTypeChecksum:
res = handler.handleCopChecksumRequest(r)
res = coprHandler{session}.handleCopChecksumRequest(r)
default:
panic(fmt.Sprintf("unknown coprocessor request type: %v", r.GetTp()))
}
@ -1066,7 +949,7 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R
}
})
r := req.BatchCop()
if err := handler.checkRequestContext(reqCtx); err != nil {
if err := session.checkRequestContext(reqCtx); err != nil {
resp.Resp = &tikvrpc.BatchCopStreamResponse{
Tikv_BatchCoprocessorClient: &mockBathCopErrClient{Error: err},
BatchResponse: &coprocessor.BatchResponse{
@ -1076,7 +959,7 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R
return resp, nil
}
ctx1, cancel := context.WithCancel(ctx)
batchCopStream, err := handler.handleBatchCopRequest(ctx1, r)
batchCopStream, err := coprHandler{session}.handleBatchCopRequest(ctx1, r)
if err != nil {
cancel()
return nil, errors.Trace(err)
@ -1094,7 +977,7 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R
resp.Resp = batchResp
case tikvrpc.CmdCopStream:
r := req.Cop()
if err := handler.checkRequestContext(reqCtx); err != nil {
if err := session.checkRequestContext(reqCtx); err != nil {
resp.Resp = &tikvrpc.CopStreamResponse{
Tikv_CoprocessorStreamClient: &mockCopStreamErrClient{Error: err},
Response: &coprocessor.Response{
@ -1103,10 +986,10 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R
}
return resp, nil
}
handler.rawStartKey = MvccKey(handler.startKey).Raw()
handler.rawEndKey = MvccKey(handler.endKey).Raw()
session.rawStartKey = MvccKey(session.startKey).Raw()
session.rawEndKey = MvccKey(session.endKey).Raw()
ctx1, cancel := context.WithCancel(ctx)
copStream, err := handler.handleCopStream(ctx1, r)
copStream, err := coprHandler{session}.handleCopStream(ctx1, r)
if err != nil {
cancel()
return nil, errors.Trace(err)
@ -1127,31 +1010,31 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R
resp.Resp = streamResp
case tikvrpc.CmdMvccGetByKey:
r := req.MvccGetByKey()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.MvccGetByKeyResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleMvccGetByKey(r)
resp.Resp = kvHandler{session}.handleMvccGetByKey(r)
case tikvrpc.CmdMvccGetByStartTs:
r := req.MvccGetByStartTs()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.MvccGetByStartTsResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleMvccGetByStartTS(r)
resp.Resp = kvHandler{session}.handleMvccGetByStartTS(r)
case tikvrpc.CmdSplitRegion:
r := req.SplitRegion()
if err := handler.checkRequest(reqCtx, r.Size()); err != nil {
if err := session.checkRequest(reqCtx, r.Size()); err != nil {
resp.Resp = &kvrpcpb.SplitRegionResponse{RegionError: err}
return resp, nil
}
resp.Resp = handler.handleSplitRegion(r)
resp.Resp = kvHandler{session}.handleSplitRegion(r)
// DebugGetRegionProperties is for fast analyze in mock tikv.
case tikvrpc.CmdDebugGetRegionProperties:
r := req.DebugGetRegionProperties()
region, _ := c.Cluster.GetRegion(r.RegionId)
var reqCtx kvrpcpb.Context
scanResp := handler.handleKvScan(&kvrpcpb.ScanRequest{
scanResp := kvHandler{session}.handleKvScan(&kvrpcpb.ScanRequest{
Context: &reqCtx,
StartKey: MvccKey(region.StartKey).Raw(),
EndKey: MvccKey(region.EndKey).Raw(),

View File

@ -0,0 +1,146 @@
// Copyright 2021 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package mocktikv
import (
"github.com/gogo/protobuf/proto"
"github.com/pingcap/kvproto/pkg/errorpb"
"github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/kvproto/pkg/metapb"
"github.com/pingcap/tidb/ddl/placement"
)
// Session stores session scope rpc data.
type Session struct {
cluster *Cluster
mvccStore MVCCStore
// storeID stores id for current request
storeID uint64
// startKey is used for handling normal request.
startKey []byte
endKey []byte
// rawStartKey is used for handling coprocessor request.
rawStartKey []byte
rawEndKey []byte
// isolationLevel is used for current request.
isolationLevel kvrpcpb.IsolationLevel
resolvedLocks []uint64
}
func (s *Session) checkRequestContext(ctx *kvrpcpb.Context) *errorpb.Error {
ctxPeer := ctx.GetPeer()
if ctxPeer != nil && ctxPeer.GetStoreId() != s.storeID {
return &errorpb.Error{
Message: *proto.String("store not match"),
StoreNotMatch: &errorpb.StoreNotMatch{},
}
}
region, leaderID := s.cluster.GetRegion(ctx.GetRegionId())
// No region found.
if region == nil {
return &errorpb.Error{
Message: *proto.String("region not found"),
RegionNotFound: &errorpb.RegionNotFound{
RegionId: *proto.Uint64(ctx.GetRegionId()),
},
}
}
var storePeer, leaderPeer *metapb.Peer
for _, p := range region.Peers {
if p.GetStoreId() == s.storeID {
storePeer = p
}
if p.GetId() == leaderID {
leaderPeer = p
}
}
// The Store does not contain a Peer of the Region.
if storePeer == nil {
return &errorpb.Error{
Message: *proto.String("region not found"),
RegionNotFound: &errorpb.RegionNotFound{
RegionId: *proto.Uint64(ctx.GetRegionId()),
},
}
}
// No leader.
if leaderPeer == nil {
return &errorpb.Error{
Message: *proto.String("no leader"),
NotLeader: &errorpb.NotLeader{
RegionId: *proto.Uint64(ctx.GetRegionId()),
},
}
}
// The Peer on the Store is not the leader. If it's a TiFlash store, we pass this check.
if storePeer.GetId() != leaderPeer.GetId() && !isTiFlashStore(s.cluster.GetStore(storePeer.GetStoreId())) {
return &errorpb.Error{
Message: *proto.String("not leader"),
NotLeader: &errorpb.NotLeader{
RegionId: *proto.Uint64(ctx.GetRegionId()),
Leader: leaderPeer,
},
}
}
// Region epoch does not match.
if !proto.Equal(region.GetRegionEpoch(), ctx.GetRegionEpoch()) {
nextRegion, _ := s.cluster.GetRegionByKey(region.GetEndKey())
currentRegions := []*metapb.Region{region}
if nextRegion != nil {
currentRegions = append(currentRegions, nextRegion)
}
return &errorpb.Error{
Message: *proto.String("epoch not match"),
EpochNotMatch: &errorpb.EpochNotMatch{
CurrentRegions: currentRegions,
},
}
}
s.startKey, s.endKey = region.StartKey, region.EndKey
s.isolationLevel = ctx.IsolationLevel
s.resolvedLocks = ctx.ResolvedLocks
return nil
}
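checkRequestContext reproduces the region errors a real TiKV would return: StoreNotMatch, RegionNotFound, NotLeader, and EpochNotMatch, with the TiFlash-store exemption on the leader check. On the client side these errors normally drive cache invalidation and a retry. The hedged sketch below shows that typical reaction; the error type is a stand-in for errorpb.Error and the cache methods are hypothetical, this is not TiDB's region-cache code.

```go
package main

import "fmt"

// regionError is a stand-in for the relevant fields of errorpb.Error.
type regionError struct {
	NotLeader      bool
	EpochNotMatch  bool
	RegionNotFound bool
	StoreNotMatch  bool
}

type regionCache struct{}

func (c *regionCache) updateLeader() { fmt.Println("switch to the reported leader and retry") }
func (c *regionCache) reloadRegion() { fmt.Println("drop the cached region, reload it, retry") }
func (c *regionCache) reloadStore()  { fmt.Println("re-resolve the store address and retry") }

// handleRegionError sketches how a client commonly reacts to each error.
func handleRegionError(c *regionCache, e *regionError) {
	switch {
	case e.NotLeader:
		c.updateLeader()
	case e.EpochNotMatch, e.RegionNotFound:
		c.reloadRegion()
	case e.StoreNotMatch:
		c.reloadStore()
	}
}

func main() {
	handleRegionError(&regionCache{}, &regionError{EpochNotMatch: true})
}
```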
func (s *Session) checkRequestSize(size int) *errorpb.Error {
// TiKV has a limitation on raft log size.
// mocktikv has no raft inside, so we check the request's size instead.
if size >= requestMaxSize {
return &errorpb.Error{
RaftEntryTooLarge: &errorpb.RaftEntryTooLarge{},
}
}
return nil
}
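checkRequestSize stands in for TiKV's raft entry size limit: mocktikv has no raft log, so it simply rejects any request whose size reaches requestMaxSize with RaftEntryTooLarge. Callers that might exceed such a limit typically split their mutations into size-bounded batches up front, roughly as in the sketch below; this is an illustrative helper, not TiDB's committer code.

```go
package main

import "fmt"

// splitBySize groups keys into batches whose total byte size stays under limit,
// the usual way a client avoids a RaftEntryTooLarge error on large writes.
func splitBySize(keys [][]byte, limit int) [][][]byte {
	var batches [][][]byte
	var cur [][]byte
	size := 0
	for _, k := range keys {
		if size+len(k) > limit && len(cur) > 0 {
			batches = append(batches, cur)
			cur, size = nil, 0
		}
		cur = append(cur, k)
		size += len(k)
	}
	if len(cur) > 0 {
		batches = append(batches, cur)
	}
	return batches
}

func main() {
	keys := [][]byte{[]byte("aaaa"), []byte("bbbb"), []byte("cccc"), []byte("dd")}
	for i, b := range splitBySize(keys, 8) {
		fmt.Printf("batch %d: %d keys\n", i, len(b))
	}
}
```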
func (s *Session) checkRequest(ctx *kvrpcpb.Context, size int) *errorpb.Error {
if err := s.checkRequestContext(ctx); err != nil {
return err
}
return s.checkRequestSize(size)
}
func (s *Session) checkKeyInRegion(key []byte) bool {
return regionContains(s.startKey, s.endKey, NewMvccKey(key))
}
func isTiFlashStore(store *metapb.Store) bool {
for _, l := range store.GetLabels() {
if l.GetKey() == placement.EngineLabelKey && l.GetValue() == placement.EngineLabelTiFlash {
return true
}
}
return false
}