executor,server: re-implement the kill statement by checking the Next() function (#10841)

This commit is contained in:
tiancaiamao
2019-06-20 11:03:21 +08:00
committed by GitHub
parent e4204df780
commit 421de5ef20
16 changed files with 61 additions and 68 deletions

View File

@ -106,7 +106,7 @@ func (a *recordSet) Next(ctx context.Context, req *chunk.RecordBatch) error {
defer span1.Finish()
}
err := a.executor.Next(ctx, req)
err := Next(ctx, a.executor, req)
if err != nil {
a.lastErr = err
return err
@ -385,7 +385,7 @@ func (a *ExecStmt) handleNoDelayExecutor(ctx context.Context, e Executor) (sqlex
a.logAudit()
}()
err = e.Next(ctx, chunk.NewRecordBatch(newFirstChunk(e)))
err = Next(ctx, e, chunk.NewRecordBatch(newFirstChunk(e)))
if err != nil {
return nil, err
}

View File

@ -555,7 +555,7 @@ func (e *HashAggExec) fetchChildData(ctx context.Context) {
}
chk = input.chk
}
err = e.children[0].Next(ctx, chunk.NewRecordBatch(chk))
err = Next(ctx, e.children[0], chunk.NewRecordBatch(chk))
if err != nil {
e.finalOutputCh <- &AfFinalResult{err: err}
return
@ -681,7 +681,7 @@ func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) erro
func (e *HashAggExec) execute(ctx context.Context) (err error) {
inputIter := chunk.NewIterator4Chunk(e.childResult)
for {
err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult))
err := Next(ctx, e.children[0], chunk.NewRecordBatch(e.childResult))
if err != nil {
return err
}
@ -870,7 +870,7 @@ func (e *StreamAggExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Ch
return err
}
err = e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult))
err = Next(ctx, e.children[0], chunk.NewRecordBatch(e.childResult))
if err != nil {
return err
}

View File

@ -105,7 +105,7 @@ func (e *DeleteExec) deleteSingleTableByChunk(ctx context.Context) error {
for {
iter := chunk.NewIterator4Chunk(chk)
err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk))
err := Next(ctx, e.children[0], chunk.NewRecordBatch(chk))
if err != nil {
return err
}
@ -187,7 +187,7 @@ func (e *DeleteExec) deleteMultiTablesByChunk(ctx context.Context) error {
chk := newFirstChunk(e.children[0])
for {
iter := chunk.NewIterator4Chunk(chk)
err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk))
err := Next(ctx, e.children[0], chunk.NewRecordBatch(chk))
if err != nil {
return err
}

View File

@ -52,6 +52,7 @@ var (
ErrWrongObject = terror.ClassExecutor.New(mysql.ErrWrongObject, mysql.MySQLErrName[mysql.ErrWrongObject])
ErrRoleNotGranted = terror.ClassPrivilege.New(mysql.ErrRoleNotGranted, mysql.MySQLErrName[mysql.ErrRoleNotGranted])
ErrDeadlock = terror.ClassExecutor.New(mysql.ErrLockDeadlock, mysql.MySQLErrName[mysql.ErrLockDeadlock])
ErrQueryInterrupted = terror.ClassExecutor.New(mysql.ErrQueryInterrupted, mysql.MySQLErrName[mysql.ErrQueryInterrupted])
)
func init() {
@ -69,6 +70,7 @@ func init() {
mysql.ErrBadDB: mysql.ErrBadDB,
mysql.ErrWrongObject: mysql.ErrWrongObject,
mysql.ErrLockDeadlock: mysql.ErrLockDeadlock,
mysql.ErrQueryInterrupted: mysql.ErrQueryInterrupted,
}
terror.ErrClassToMySQLCodes[terror.ClassExecutor] = tableMySQLErrCodes
}

View File

@ -180,6 +180,16 @@ type Executor interface {
Schema() *expression.Schema
}
// Next is a wrapper function on e.Next(), it handles some common codes.
func Next(ctx context.Context, e Executor, req *chunk.RecordBatch) error {
sessVars := e.base().ctx.GetSessionVars()
if atomic.CompareAndSwapUint32(&sessVars.Killed, 1, 0) {
return ErrQueryInterrupted
}
return e.Next(ctx, req)
}
// CancelDDLJobsExec represents a cancel DDL jobs executor.
type CancelDDLJobsExec struct {
baseExecutor
@ -559,7 +569,7 @@ func (e *CheckIndexExec) Next(ctx context.Context, req *chunk.RecordBatch) error
}
chk := newFirstChunk(e.src)
for {
err := e.src.Next(ctx, chunk.NewRecordBatch(chk))
err := Next(ctx, e.src, chunk.NewRecordBatch(chk))
if err != nil {
return err
}
@ -668,7 +678,7 @@ func (e *SelectLockExec) Next(ctx context.Context, req *chunk.RecordBatch) error
}
req.GrowAndReset(e.maxChunkSize)
err := e.children[0].Next(ctx, req)
err := Next(ctx, e.children[0], req)
if err != nil {
return err
}
@ -728,7 +738,7 @@ func (e *LimitExec) Next(ctx context.Context, req *chunk.RecordBatch) error {
for !e.meetFirstBatch {
// transfer req's requiredRows to childResult and then adjust it in childResult
e.childResult = e.childResult.SetRequiredRows(req.RequiredRows(), e.maxChunkSize)
err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.adjustRequiredRows(e.childResult)))
err := Next(ctx, e.children[0], chunk.NewRecordBatch(e.adjustRequiredRows(e.childResult)))
if err != nil {
return err
}
@ -753,7 +763,7 @@ func (e *LimitExec) Next(ctx context.Context, req *chunk.RecordBatch) error {
e.cursor += batchSize
}
e.adjustRequiredRows(req.Chunk)
err := e.children[0].Next(ctx, req)
err := Next(ctx, e.children[0], req)
if err != nil {
return err
}
@ -823,7 +833,7 @@ func init() {
}
chk := newFirstChunk(exec)
for {
err = exec.Next(ctx, chunk.NewRecordBatch(chk))
err = Next(ctx, exec, chunk.NewRecordBatch(chk))
if err != nil {
return rows, err
}
@ -940,7 +950,7 @@ func (e *SelectionExec) Next(ctx context.Context, req *chunk.RecordBatch) error
}
req.AppendRow(e.inputRow)
}
err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult))
err := Next(ctx, e.children[0], chunk.NewRecordBatch(e.childResult))
if err != nil {
return err
}
@ -972,7 +982,7 @@ func (e *SelectionExec) unBatchedNext(ctx context.Context, chk *chunk.Chunk) err
return nil
}
}
err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult))
err := Next(ctx, e.children[0], chunk.NewRecordBatch(e.childResult))
if err != nil {
return err
}
@ -1120,7 +1130,7 @@ func (e *MaxOneRowExec) Next(ctx context.Context, req *chunk.RecordBatch) error
return nil
}
e.evaluated = true
err := e.children[0].Next(ctx, req)
err := Next(ctx, e.children[0], req)
if err != nil {
return err
}
@ -1135,7 +1145,7 @@ func (e *MaxOneRowExec) Next(ctx context.Context, req *chunk.RecordBatch) error
}
childChunk := newFirstChunk(e.children[0])
err = e.children[0].Next(ctx, chunk.NewRecordBatch(childChunk))
err = Next(ctx, e.children[0], chunk.NewRecordBatch(childChunk))
if err != nil {
return err
}
@ -1246,7 +1256,7 @@ func (e *UnionExec) resultPuller(ctx context.Context, childID int) {
return
case result.chk = <-e.resourcePools[childID]:
}
result.err = e.children[childID].Next(ctx, chunk.NewRecordBatch(result.chk))
result.err = Next(ctx, e.children[childID], chunk.NewRecordBatch(result.chk))
if result.err == nil && result.chk.NumRows() == 0 {
return
}

View File

@ -386,7 +386,7 @@ func (ow *outerWorker) buildTask(ctx context.Context) (*lookUpJoinTask, error) {
task.memTracker.Consume(task.outerResult.MemoryUsage())
for !task.outerResult.IsFull() {
err := ow.executor.Next(ctx, chunk.NewRecordBatch(ow.executorChk))
err := Next(ctx, ow.executor, chunk.NewRecordBatch(ow.executorChk))
if err != nil {
return task, err
}
@ -586,7 +586,7 @@ func (iw *innerWorker) fetchInnerResults(ctx context.Context, task *lookUpJoinTa
innerResult.GetMemTracker().SetLabel(innerResultLabel)
innerResult.GetMemTracker().AttachTo(task.memTracker)
for {
err := innerExec.Next(ctx, chunk.NewRecordBatch(iw.executorChk))
err := Next(ctx, innerExec, chunk.NewRecordBatch(iw.executorChk))
if err != nil {
return err
}

View File

@ -202,6 +202,7 @@ func (e *HashJoinExec) fetchOuterChunks(ctx context.Context) {
if e.finished.Load().(bool) {
return
}
var outerResource *outerChkResource
var ok bool
select {
@ -217,7 +218,7 @@ func (e *HashJoinExec) fetchOuterChunks(ctx context.Context) {
required := int(atomic.LoadInt64(&e.requiredRows))
outerResult.SetRequiredRows(required, e.maxChunkSize)
}
err := e.outerExec.Next(ctx, chunk.NewRecordBatch(outerResult))
err := Next(ctx, e.outerExec, chunk.NewRecordBatch(outerResult))
if err != nil {
e.joinResultCh <- &hashjoinWorkerResult{
err: err,
@ -244,6 +245,7 @@ func (e *HashJoinExec) fetchOuterChunks(ctx context.Context) {
if outerResult.NumRows() == 0 {
return
}
outerResource.dest <- outerResult
}
}
@ -276,8 +278,9 @@ func (e *HashJoinExec) fetchInnerRows(ctx context.Context) error {
if e.finished.Load().(bool) {
return nil
}
chk := newFirstChunk(e.children[e.innerIdx])
err = e.innerExec.Next(ctx, chunk.NewRecordBatch(chk))
err = Next(ctx, e.innerExec, chunk.NewRecordBatch(chk))
if err != nil || chk.NumRows() == 0 {
return err
}
@ -512,6 +515,7 @@ func (e *HashJoinExec) Next(ctx context.Context, req *chunk.RecordBatch) (err er
if e.joinResultCh == nil {
return nil
}
result, ok := <-e.joinResultCh
if !ok {
return nil
@ -642,7 +646,7 @@ func (e *NestedLoopApplyExec) fetchSelectedOuterRow(ctx context.Context, chk *ch
outerIter := chunk.NewIterator4Chunk(e.outerChunk)
for {
if e.outerChunkCursor >= e.outerChunk.NumRows() {
err := e.outerExec.Next(ctx, chunk.NewRecordBatch(e.outerChunk))
err := Next(ctx, e.outerExec, chunk.NewRecordBatch(e.outerChunk))
if err != nil {
return nil, err
}
@ -679,7 +683,7 @@ func (e *NestedLoopApplyExec) fetchAllInners(ctx context.Context) error {
e.innerList.Reset()
innerIter := chunk.NewIterator4Chunk(e.innerChunk)
for {
err := e.innerExec.Next(ctx, chunk.NewRecordBatch(e.innerChunk))
err := Next(ctx, e.innerExec, chunk.NewRecordBatch(e.innerChunk))
if err != nil {
return err
}

View File

@ -142,7 +142,7 @@ func (t *mergeJoinInnerTable) nextRow() (chunk.Row, error) {
if t.curRow == t.curIter.End() {
t.reallocReaderResult()
oldMemUsage := t.curResult.MemoryUsage()
err := t.reader.Next(t.ctx, chunk.NewRecordBatch(t.curResult))
err := Next(t.ctx, t.reader, chunk.NewRecordBatch(t.curResult))
// error happens or no more data.
if err != nil || t.curResult.NumRows() == 0 {
t.curRow = t.curIter.End()
@ -389,7 +389,7 @@ func (e *MergeJoinExec) fetchNextOuterRows(ctx context.Context, requiredRows int
e.outerTable.chk.SetRequiredRows(requiredRows, e.maxChunkSize)
}
err = e.outerTable.reader.Next(ctx, chunk.NewRecordBatch(e.outerTable.chk))
err = Next(ctx, e.outerTable.reader, chunk.NewRecordBatch(e.outerTable.chk))
if err != nil {
return err
}

View File

@ -179,7 +179,7 @@ func (e *ProjectionExec) isUnparallelExec() bool {
func (e *ProjectionExec) unParallelExecute(ctx context.Context, chk *chunk.Chunk) error {
// transmit the requiredRows
e.childResult.SetRequiredRows(chk.RequiredRows(), e.maxChunkSize)
err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult))
err := Next(ctx, e.children[0], chunk.NewRecordBatch(e.childResult))
if err != nil {
return err
}
@ -312,7 +312,7 @@ func (f *projectionInputFetcher) run(ctx context.Context) {
requiredRows := atomic.LoadInt64(&f.proj.parentReqRows)
input.chk.SetRequiredRows(int(requiredRows), f.proj.maxChunkSize)
err := f.child.Next(ctx, chunk.NewRecordBatch(input.chk))
err := Next(ctx, f.child, chunk.NewRecordBatch(input.chk))
if err != nil || input.chk.NumRows() == 0 {
output.done <- err
return

View File

@ -111,7 +111,7 @@ func (e *SortExec) fetchRowChunks(ctx context.Context) error {
e.rowChunks.GetMemTracker().SetLabel(rowChunksLabel)
for {
chk := newFirstChunk(e.children[0])
err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk))
err := Next(ctx, e.children[0], chunk.NewRecordBatch(chk))
if err != nil {
return err
}
@ -282,7 +282,7 @@ func (e *TopNExec) loadChunksUntilTotalLimit(ctx context.Context) error {
srcChk := newFirstChunk(e.children[0])
// adjust required rows by total limit
srcChk.SetRequiredRows(int(e.totalLimit-uint64(e.rowChunks.Len())), e.maxChunkSize)
err := e.children[0].Next(ctx, chunk.NewRecordBatch(srcChk))
err := Next(ctx, e.children[0], chunk.NewRecordBatch(srcChk))
if err != nil {
return err
}
@ -307,7 +307,7 @@ func (e *TopNExec) executeTopN(ctx context.Context) error {
}
childRowChk := newFirstChunk(e.children[0])
for {
err := e.children[0].Next(ctx, chunk.NewRecordBatch(childRowChk))
err := Next(ctx, e.children[0], chunk.NewRecordBatch(childRowChk))
if err != nil {
return err
}

View File

@ -199,7 +199,7 @@ func (us *UnionScanExec) getSnapshotRow(ctx context.Context) ([]types.Datum, err
us.cursor4SnapshotRows = 0
us.snapshotRows = us.snapshotRows[:0]
for len(us.snapshotRows) == 0 {
err = us.children[0].Next(ctx, chunk.NewRecordBatch(us.snapshotChunkBuffer))
err = Next(ctx, us.children[0], chunk.NewRecordBatch(us.snapshotChunkBuffer))
if err != nil || us.snapshotChunkBuffer.NumRows() == 0 {
return nil, err
}

View File

@ -181,7 +181,7 @@ func (e *UpdateExec) fetchChunkRows(ctx context.Context) error {
chk := newFirstChunk(e.children[0])
e.evalBuffer = chunk.MutRowFromTypes(fields)
for {
err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk))
err := Next(ctx, e.children[0], chunk.NewRecordBatch(chk))
if err != nil {
return err
}

View File

@ -131,7 +131,7 @@ func (e *WindowExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Chunk
}
childResult := newFirstChunk(e.children[0])
err = e.children[0].Next(ctx, &chunk.RecordBatch{Chunk: childResult})
err = Next(ctx, e.children[0], &chunk.RecordBatch{Chunk: childResult})
if err != nil {
return errors.Trace(err)
}

View File

@ -45,7 +45,6 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@ -150,13 +149,6 @@ type clientConn struct {
peerHost string // peer host
peerPort string // peer port
lastCode uint16 // last error code
// mu is used for cancelling the execution of current transaction.
mu struct {
sync.RWMutex
cancelFunc context.CancelFunc
resultSets []ResultSet
}
}
func (cc *clientConn) String() string {
@ -847,11 +839,6 @@ func (cc *clientConn) addMetrics(cmd byte, startTime time.Time, err error) {
func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
span := opentracing.StartSpan("server.dispatch")
ctx1, cancelFunc := context.WithCancel(ctx)
cc.mu.Lock()
cc.mu.cancelFunc = cancelFunc
cc.mu.Unlock()
t := time.Now()
cmd := data[0]
data = data[1:]
@ -863,6 +850,8 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
span.Finish()
}()
vars := cc.ctx.GetSessionVars()
atomic.StoreUint32(&vars.Killed, 0)
if cmd < mysql.ComEnd {
cc.ctx.SetCommandValue(cmd)
}
@ -893,11 +882,11 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
data = data[:len(data)-1]
dataStr = string(hack.String(data))
}
return cc.handleQuery(ctx1, dataStr)
return cc.handleQuery(ctx, dataStr)
case mysql.ComPing:
return cc.writeOK()
case mysql.ComInitDB:
if err := cc.useDB(ctx1, dataStr); err != nil {
if err := cc.useDB(ctx, dataStr); err != nil {
return err
}
return cc.writeOK()
@ -906,9 +895,9 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
case mysql.ComStmtPrepare:
return cc.handleStmtPrepare(dataStr)
case mysql.ComStmtExecute:
return cc.handleStmtExecute(ctx1, data)
return cc.handleStmtExecute(ctx, data)
case mysql.ComStmtFetch:
return cc.handleStmtFetch(ctx1, data)
return cc.handleStmtFetch(ctx, data)
case mysql.ComStmtClose:
return cc.handleStmtClose(data)
case mysql.ComStmtSendLongData:
@ -918,7 +907,7 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
case mysql.ComSetOption:
return cc.handleSetOption(data)
case mysql.ComChangeUser:
return cc.handleChangeUser(ctx1, data)
return cc.handleChangeUser(ctx, data)
default:
return mysql.NewErrf(mysql.ErrUnknown, "command %d not supported now", cmd)
}
@ -1171,15 +1160,11 @@ func (cc *clientConn) handleQuery(ctx context.Context, sql string) (err error) {
metrics.ExecuteErrorCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err)).Inc()
return err
}
cc.mu.Lock()
cc.mu.resultSets = rs
status := atomic.LoadInt32(&cc.status)
if status == connStatusShutdown || status == connStatusWaitShutdown {
cc.mu.Unlock()
killConn(cc)
return errors.New("killed by another connection")
}
cc.mu.Unlock()
if rs != nil {
if len(rs) == 1 {
err = cc.writeResultset(ctx, rs[0], false, 0, 0)

View File

@ -530,19 +530,8 @@ func (s *Server) Kill(connectionID uint64, query bool) {
}
func killConn(conn *clientConn) {
conn.mu.RLock()
resultSets := conn.mu.resultSets
cancelFunc := conn.mu.cancelFunc
conn.mu.RUnlock()
for _, resultSet := range resultSets {
// resultSet.Close() is reentrant so it's safe to kill a same connID multiple times
if err := resultSet.Close(); err != nil {
logutil.Logger(context.Background()).Error("close result set error", zap.Uint32("connID", conn.connectionID), zap.Error(err))
}
}
if cancelFunc != nil {
cancelFunc()
}
sessVars := conn.ctx.GetSessionVars()
atomic.CompareAndSwapUint32(&sessVars.Killed, 0, 1)
}
// KillAllConnections kills all connections when server is not gracefully shutdown.

View File

@ -379,6 +379,9 @@ type SessionVars struct {
// LowResolutionTSO is used for reading data with low resolution TSO which is updated once every two seconds.
LowResolutionTSO bool
// Killed is a flag to indicate that this query is killed.
Killed uint32
}
// ConnectionInfo present connection used by audit.