diff --git a/store/mockstore/mocktikv/aggregate.go b/store/mockstore/mocktikv/aggregate.go index cb81305493..d9d988a740 100644 --- a/store/mockstore/mocktikv/aggregate.go +++ b/store/mockstore/mocktikv/aggregate.go @@ -54,6 +54,11 @@ func (e *hashAggExec) GetSrcExec() executor { return e.src } +func (e *hashAggExec) ResetCount() { + e.count = 0 + e.src.ResetCount() +} + func (e *hashAggExec) Count() int64 { return e.count } @@ -73,6 +78,10 @@ func (e *hashAggExec) innerNext(ctx context.Context) (bool, error) { return true, nil } +func (e *hashAggExec) Cursor() ([]byte, bool) { + panic("don't not use coprocessor streaming API for hash aggregation!") +} + func (e *hashAggExec) Next(ctx context.Context) (value [][]byte, err error) { e.count++ if e.aggCtxsMap == nil { @@ -205,6 +214,11 @@ func (e *streamAggExec) GetSrcExec() executor { return e.src } +func (e *streamAggExec) ResetCount() { + e.count = 0 + e.src.ResetCount() +} + func (e *streamAggExec) Count() int64 { return e.count } @@ -269,6 +283,10 @@ func (e *streamAggExec) meetNewGroup(row [][]byte) (bool, error) { return !firstGroup, nil } +func (e *streamAggExec) Cursor() ([]byte, bool) { + panic("don't not use coprocessor streaming API for stream aggregation!") +} + func (e *streamAggExec) Next(ctx context.Context) (retRow [][]byte, err error) { e.count++ if e.executed { diff --git a/store/mockstore/mocktikv/cop_handler_dag.go b/store/mockstore/mocktikv/cop_handler_dag.go index 68c36e47ca..3f3b641f9b 100644 --- a/store/mockstore/mocktikv/cop_handler_dag.go +++ b/store/mockstore/mocktikv/cop_handler_dag.go @@ -435,7 +435,9 @@ func (mock *mockCopStreamClient) Recv() (*coprocessor.Response, error) { } var resp coprocessor.Response - chunk, finish, err := mock.readBlockFromExecutor() + counts := make([]int64, len(mock.req.Executors)) + chunk, finish, ran, err := mock.readBlockFromExecutor(counts) + resp.Range = ran if err != nil { if locked, ok := errors.Cause(err).(*ErrLocked); ok { resp.Locked = &kvrpcpb.LockInfo{ @@ -460,9 +462,10 @@ func (mock *mockCopStreamClient) Recv() (*coprocessor.Response, error) { return &resp, nil } streamResponse := tipb.StreamResponse{ - Error: toPBError(err), - EncodeType: tipb.EncodeType_TypeDefault, - Data: data, + Error: toPBError(err), + EncodeType: tipb.EncodeType_TypeDefault, + Data: data, + OutputCounts: counts, } resp.Data, err = proto.Marshal(&streamResponse) if err != nil { @@ -471,21 +474,41 @@ func (mock *mockCopStreamClient) Recv() (*coprocessor.Response, error) { return &resp, nil } -func (mock *mockCopStreamClient) readBlockFromExecutor() (tipb.Chunk, bool, error) { +func (mock *mockCopStreamClient) readBlockFromExecutor(counts []int64) (tipb.Chunk, bool, *coprocessor.KeyRange, error) { var chunk tipb.Chunk + var ran coprocessor.KeyRange + var finish bool + var desc bool + mock.exec.ResetCount() + ran.Start, desc = mock.exec.Cursor() for count := 0; count < rowsPerChunk; count++ { row, err := mock.exec.Next(mock.ctx) if err != nil { - return chunk, false, errors.Trace(err) + return chunk, false, nil, errors.Trace(err) } if row == nil { - return chunk, true, nil + finish = true + break } for _, offset := range mock.req.OutputOffsets { chunk.RowsData = append(chunk.RowsData, row[offset]...) } } - return chunk, false, nil + + ran.End, _ = mock.exec.Cursor() + if desc { + ran.Start, ran.End = ran.End, ran.Start + } + e := mock.exec + for offset := len(mock.req.Executors) - 1; e != nil; e, offset = e.GetSrcExec(), offset-1 { + count := e.Count() + // Because the last call to `executor.Next` always returns a `nil`, so the actual count should be `Count - 1` + if finish { + count-- + } + counts[offset] = count + } + return chunk, finish, &ran, nil } func buildResp(chunks []tipb.Chunk, counts []int64, err error) *coprocessor.Response { diff --git a/store/mockstore/mocktikv/executor.go b/store/mockstore/mocktikv/executor.go index 267b32e83a..c5802eb6c7 100644 --- a/store/mockstore/mocktikv/executor.go +++ b/store/mockstore/mocktikv/executor.go @@ -43,8 +43,11 @@ var ( type executor interface { SetSrcExec(executor) GetSrcExec() executor + ResetCount() Count() int64 Next(ctx context.Context) ([][]byte, error) + // Cursor returns the key gonna to be scanned by the Next() function. + Cursor() (key []byte, desc bool) } type tableScanExec struct { @@ -69,10 +72,37 @@ func (e *tableScanExec) GetSrcExec() executor { return e.src } +func (e *tableScanExec) ResetCount() { + e.count = 0 +} + func (e *tableScanExec) Count() int64 { return e.count } +func (e *tableScanExec) Cursor() ([]byte, bool) { + if len(e.seekKey) > 0 { + return e.seekKey, e.Desc + } + + if e.cursor < len(e.kvRanges) { + ran := e.kvRanges[e.cursor] + if ran.IsPoint() { + return ran.StartKey, e.Desc + } + + if e.Desc { + return ran.EndKey, e.Desc + } + return ran.StartKey, e.Desc + } + + if e.Desc { + return e.kvRanges[len(e.kvRanges)-1].StartKey, e.Desc + } + return e.kvRanges[len(e.kvRanges)-1].EndKey, e.Desc +} + func (e *tableScanExec) Next(ctx context.Context) (value [][]byte, err error) { e.count++ for e.cursor < len(e.kvRanges) { @@ -200,6 +230,10 @@ func (e *indexScanExec) GetSrcExec() executor { return e.src } +func (e *indexScanExec) ResetCount() { + e.count = 0 +} + func (e *indexScanExec) Count() int64 { return e.count } @@ -208,6 +242,26 @@ func (e *indexScanExec) isUnique() bool { return e.Unique != nil && *e.Unique } +func (e *indexScanExec) Cursor() ([]byte, bool) { + if len(e.seekKey) > 0 { + return e.seekKey, e.Desc + } + if e.cursor < len(e.kvRanges) { + ran := e.kvRanges[e.cursor] + if ran.IsPoint() && e.isUnique() { + return ran.StartKey, e.Desc + } + if e.Desc { + return ran.EndKey, e.Desc + } + return ran.StartKey, e.Desc + } + if e.Desc { + return e.kvRanges[len(e.kvRanges)-1].StartKey, e.Desc + } + return e.kvRanges[len(e.kvRanges)-1].EndKey, e.Desc +} + func (e *indexScanExec) Next(ctx context.Context) (value [][]byte, err error) { e.count++ for e.cursor < len(e.kvRanges) { @@ -337,6 +391,11 @@ func (e *selectionExec) GetSrcExec() executor { return e.src } +func (e *selectionExec) ResetCount() { + e.count = 0 + e.src.ResetCount() +} + func (e *selectionExec) Count() int64 { return e.count } @@ -363,6 +422,10 @@ func evalBool(exprs []expression.Expression, row types.DatumRow, ctx *stmtctx.St return true, nil } +func (e *selectionExec) Cursor() ([]byte, bool) { + return e.src.Cursor() +} + func (e *selectionExec) Next(ctx context.Context) (value [][]byte, err error) { e.count++ for { @@ -409,6 +472,11 @@ func (e *topNExec) GetSrcExec() executor { return e.src } +func (e *topNExec) ResetCount() { + e.count = 0 + e.src.ResetCount() +} + func (e *topNExec) Count() int64 { return e.count } @@ -428,6 +496,10 @@ func (e *topNExec) innerNext(ctx context.Context) (bool, error) { return true, nil } +func (e *topNExec) Cursor() ([]byte, bool) { + panic("don't not use coprocessor streaming API for topN!") +} + func (e *topNExec) Next(ctx context.Context) (value [][]byte, err error) { e.count++ if !e.executed { @@ -493,10 +565,19 @@ func (e *limitExec) GetSrcExec() executor { return e.src } +func (e *limitExec) ResetCount() { + e.count = 0 + e.src.ResetCount() +} + func (e *limitExec) Count() int64 { return e.count } +func (e *limitExec) Cursor() ([]byte, bool) { + return e.src.Cursor() +} + func (e *limitExec) Next(ctx context.Context) (value [][]byte, err error) { e.count++ if e.cursor >= e.limit {