server,session: code refactor for multiple statements in one query (#16056)

This commit is contained in:
tiancaiamao
2020-04-14 20:16:43 +08:00
committed by GitHub
parent c66320c464
commit dc139aedb4
7 changed files with 234 additions and 127 deletions

View File

@ -54,6 +54,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
@ -931,7 +932,11 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
func (cc *clientConn) useDB(ctx context.Context, db string) (err error) {
// if input is "use `SELECT`", mysql client just send "SELECT"
// so we add `` around db.
_, err = cc.ctx.Execute(ctx, "use `"+db+"`")
stmts, err := cc.ctx.Parse(ctx, "use `"+db+"`")
if err != nil {
return err
}
_, err = cc.ctx.ExecuteStmt(ctx, stmts[0])
if err != nil {
return err
}
@ -1263,55 +1268,84 @@ func (cc *clientConn) handleIndexAdvise(ctx context.Context, indexAdviseInfo *ex
// There is a special query `load data` that does not return result, which is handled differently.
// Query `load stats` does not return result either.
func (cc *clientConn) handleQuery(ctx context.Context, sql string) (err error) {
rss, err := cc.ctx.Execute(ctx, sql)
stmts, err := cc.ctx.Parse(ctx, sql)
if err != nil {
metrics.ExecuteErrorCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err)).Inc()
return err
}
status := atomic.LoadInt32(&cc.status)
if rss != nil && (status == connStatusShutdown || status == connStatusWaitShutdown) {
for _, rs := range rss {
terror.Call(rs.Close)
for i, stmt := range stmts {
if err = cc.handleStmt(ctx, stmt, i == len(stmts)-1); err != nil {
break
}
return executor.ErrQueryInterrupted
}
if rss != nil {
if len(rss) == 1 {
err = cc.writeResultset(ctx, rss[0], false, 0, 0)
} else {
err = cc.writeMultiResultset(ctx, rss, false)
}
} else {
loadDataInfo := cc.ctx.Value(executor.LoadDataVarKey)
if loadDataInfo != nil {
defer cc.ctx.SetValue(executor.LoadDataVarKey, nil)
if err = cc.handleLoadData(ctx, loadDataInfo.(*executor.LoadDataInfo)); err != nil {
return err
}
}
loadStats := cc.ctx.Value(executor.LoadStatsVarKey)
if loadStats != nil {
defer cc.ctx.SetValue(executor.LoadStatsVarKey, nil)
if err = cc.handleLoadStats(ctx, loadStats.(*executor.LoadStatsInfo)); err != nil {
return err
}
}
indexAdvise := cc.ctx.Value(executor.IndexAdviseVarKey)
if indexAdvise != nil {
defer cc.ctx.SetValue(executor.IndexAdviseVarKey, nil)
err = cc.handleIndexAdvise(ctx, indexAdvise.(*executor.IndexAdviseInfo))
if err != nil {
return err
}
}
err = cc.writeOK()
if err != nil {
metrics.ExecuteErrorCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err)).Inc()
}
return err
}
func (cc *clientConn) handleStmt(ctx context.Context, stmt ast.StmtNode, lastStmt bool) error {
rs, err := cc.ctx.ExecuteStmt(ctx, stmt)
if rs != nil {
defer terror.Call(rs.Close)
}
if err != nil {
return err
}
status := cc.ctx.Status()
if !lastStmt {
status |= mysql.ServerMoreResultsExists
}
if rs != nil {
connStatus := atomic.LoadInt32(&cc.status)
if connStatus == connStatusShutdown || connStatus == connStatusWaitShutdown {
return executor.ErrQueryInterrupted
}
err = cc.writeResultset(ctx, rs, false, status, 0)
if err != nil {
return err
}
} else {
err = cc.handleQuerySpecial(ctx, status)
if err != nil {
return err
}
}
return nil
}
func (cc *clientConn) handleQuerySpecial(ctx context.Context, status uint16) error {
loadDataInfo := cc.ctx.Value(executor.LoadDataVarKey)
if loadDataInfo != nil {
defer cc.ctx.SetValue(executor.LoadDataVarKey, nil)
if err := cc.handleLoadData(ctx, loadDataInfo.(*executor.LoadDataInfo)); err != nil {
return err
}
}
loadStats := cc.ctx.Value(executor.LoadStatsVarKey)
if loadStats != nil {
defer cc.ctx.SetValue(executor.LoadStatsVarKey, nil)
if err := cc.handleLoadStats(ctx, loadStats.(*executor.LoadStatsInfo)); err != nil {
return err
}
}
indexAdvise := cc.ctx.Value(executor.IndexAdviseVarKey)
if indexAdvise != nil {
defer cc.ctx.SetValue(executor.IndexAdviseVarKey, nil)
if err := cc.handleIndexAdvise(ctx, indexAdvise.(*executor.IndexAdviseInfo)); err != nil {
return err
}
}
return cc.writeOkWith(cc.ctx.LastMessage(), cc.ctx.AffectedRows(), cc.ctx.LastInsertID(), status, cc.ctx.WarningCount())
}
// handleFieldList returns the field list for a table.
// The sql string is composed of a table name and a terminating character \x00.
func (cc *clientConn) handleFieldList(sql string) (err error) {
@ -1344,13 +1378,9 @@ func (cc *clientConn) handleFieldList(sql string) (err error) {
// If binary is true, the data would be encoded in BINARY format.
// serverStatus, a flag bit represents server information.
// fetchSize, the desired number of rows to be fetched each time when client uses cursor.
// resultsets, it's used to support the MULTI_RESULTS capability in mysql protocol.
func (cc *clientConn) writeResultset(ctx context.Context, rs ResultSet, binary bool, serverStatus uint16, fetchSize int) (runErr error) {
defer func() {
// close ResultSet when cursor doesn't exist
if !mysql.HasCursorExistsFlag(serverStatus) {
terror.Call(rs.Close)
}
r := recover()
if r == nil {
return

View File

@ -44,6 +44,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
@ -204,6 +205,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
// explicitly flush columnInfo to client.
return cc.flush()
}
defer terror.Call(rs.Close)
err = cc.writeResultset(ctx, rs, true, 0, 0)
if err != nil {
return errors.Annotate(err, cc.preparedStmt2String(stmtID))

View File

@ -23,7 +23,6 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/domain"
@ -33,7 +32,6 @@ import (
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/store/mockstore"
"github.com/pingcap/tidb/util/arena"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/testleak"
)
@ -440,41 +438,27 @@ func (ts *ConnTestSuite) TestConnExecutionTimeout(c *C) {
type mockTiDBCtx struct {
TiDBContext
rs []ResultSet
err error
}
func (c *mockTiDBCtx) Execute(ctx context.Context, sql string) ([]ResultSet, error) {
return c.rs, c.err
}
func (c *mockTiDBCtx) ExecuteInternal(ctx context.Context, sql string) ([]ResultSet, error) {
return c.rs, c.err
}
func (c *mockTiDBCtx) GetSessionVars() *variable.SessionVars {
return &variable.SessionVars{}
}
type mockRecordSet struct{}
func (m mockRecordSet) Fields() []*ast.ResultField { return nil }
func (m mockRecordSet) Next(ctx context.Context, req *chunk.Chunk) error { return nil }
func (m mockRecordSet) NewChunk() *chunk.Chunk { return nil }
func (m mockRecordSet) Close() error { return nil }
func (ts *ConnTestSuite) TestShutDown(c *C) {
cc := &clientConn{}
rs := &tidbResultSet{recordSet: mockRecordSet{}}
se, err := session.CreateSession4Test(ts.store)
c.Assert(err, IsNil)
// mock delay response
cc.ctx = &mockTiDBCtx{rs: []ResultSet{rs}, err: nil}
cc.ctx = &mockTiDBCtx{
TiDBContext: TiDBContext{session: se},
err: nil,
}
// set killed flag
cc.status = connStatusShutdown
// assert ErrQueryInterrupted
err := cc.handleQuery(context.Background(), "dummy")
err = cc.handleQuery(context.Background(), "select 1")
c.Assert(err, Equals, executor.ErrQueryInterrupted)
c.Assert(rs.closed, Equals, int32(1))
}
func (ts *ConnTestSuite) TestShutdownOrNotify(c *C) {

View File

@ -19,6 +19,7 @@ import (
"fmt"
"time"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/auth"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
@ -66,11 +67,11 @@ type QueryCtx interface {
// CurrentDB returns current DB.
CurrentDB() string
// Execute executes a SQL statement.
Execute(ctx context.Context, sql string) ([]ResultSet, error)
// ExecuteStmt executes a SQL statement.
ExecuteStmt(context.Context, ast.StmtNode) (ResultSet, error)
// ExecuteInternal executes a internal SQL statement.
ExecuteInternal(ctx context.Context, sql string) ([]ResultSet, error)
// Parse parses a SQL to statement node.
Parse(ctx context.Context, sql string) ([]ast.StmtNode, error)
// SetClientCapability sets client capability flags
SetClientCapability(uint32)

View File

@ -243,40 +243,23 @@ func (tc *TiDBContext) WarningCount() uint16 {
return tc.session.GetSessionVars().StmtCtx.WarningCount()
}
// Execute implements QueryCtx Execute method.
func (tc *TiDBContext) Execute(ctx context.Context, sql string) (rs []ResultSet, err error) {
rsList, err := tc.session.Execute(ctx, sql)
// ExecuteStmt implements QueryCtx interface.
func (tc *TiDBContext) ExecuteStmt(ctx context.Context, stmt ast.StmtNode) (ResultSet, error) {
rs, err := tc.session.ExecuteStmt(ctx, stmt)
if err != nil {
return
return nil, err
}
if len(rsList) == 0 { // result ok
return
if rs == nil {
return nil, nil
}
rs = make([]ResultSet, len(rsList))
for i := 0; i < len(rsList); i++ {
rs[i] = &tidbResultSet{
recordSet: rsList[i],
}
}
return
return &tidbResultSet{
recordSet: rs,
}, nil
}
// ExecuteInternal implements QueryCtx ExecuteInternal method.
func (tc *TiDBContext) ExecuteInternal(ctx context.Context, sql string) (rs []ResultSet, err error) {
rsList, err := tc.session.ExecuteInternal(ctx, sql)
if err != nil {
return
}
if len(rsList) == 0 { // result ok
return
}
rs = make([]ResultSet, len(rsList))
for i := 0; i < len(rsList); i++ {
rs[i] = &tidbResultSet{
recordSet: rsList[i],
}
}
return
// Parse implements QueryCtx interface.
func (tc *TiDBContext) Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) {
return tc.session.Parse(ctx, sql)
}
// SetSessionManager implements the QueryCtx interface.

View File

@ -729,7 +729,7 @@ func (ts *tidbTestSuite) TestCreateTableFlen(c *C) {
// issue #4540
qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil)
c.Assert(err, IsNil)
_, err = qctx.Execute(context.Background(), "use test;")
_, err = Execute(context.Background(), qctx, "use test;")
c.Assert(err, IsNil)
ctx := context.Background()
@ -762,44 +762,55 @@ func (ts *tidbTestSuite) TestCreateTableFlen(c *C) {
"`z` decimal(20, 4)," +
"PRIMARY KEY (`a`)" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin"
_, err = qctx.Execute(ctx, testSQL)
_, err = Execute(ctx, qctx, testSQL)
c.Assert(err, IsNil)
rs, err := qctx.Execute(ctx, "show create table t1")
rs, err := Execute(ctx, qctx, "show create table t1")
c.Assert(err, IsNil)
req := rs[0].NewChunk()
err = rs[0].Next(ctx, req)
req := rs.NewChunk()
err = rs.Next(ctx, req)
c.Assert(err, IsNil)
cols := rs[0].Columns()
cols := rs.Columns()
c.Assert(err, IsNil)
c.Assert(len(cols), Equals, 2)
c.Assert(int(cols[0].ColumnLength), Equals, 5*tmysql.MaxBytesOfCharacter)
c.Assert(int(cols[1].ColumnLength), Equals, len(req.GetRow(0).GetString(1))*tmysql.MaxBytesOfCharacter)
// for issue#5246
rs, err = qctx.Execute(ctx, "select y, z from t1")
rs, err = Execute(ctx, qctx, "select y, z from t1")
c.Assert(err, IsNil)
cols = rs[0].Columns()
cols = rs.Columns()
c.Assert(len(cols), Equals, 2)
c.Assert(int(cols[0].ColumnLength), Equals, 21)
c.Assert(int(cols[1].ColumnLength), Equals, 22)
}
func Execute(ctx context.Context, qc QueryCtx, sql string) (ResultSet, error) {
stmts, err := qc.Parse(ctx, sql)
if err != nil {
return nil, err
}
if len(stmts) != 1 {
panic("wrong input for Execute: " + sql)
}
return qc.ExecuteStmt(ctx, stmts[0])
}
func (ts *tidbTestSuite) TestShowTablesFlen(c *C) {
qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil)
c.Assert(err, IsNil)
_, err = qctx.Execute(context.Background(), "use test;")
ctx := context.Background()
_, err = Execute(ctx, qctx, "use test;")
c.Assert(err, IsNil)
ctx := context.Background()
testSQL := "create table abcdefghijklmnopqrstuvwxyz (i int)"
_, err = qctx.Execute(ctx, testSQL)
_, err = Execute(ctx, qctx, testSQL)
c.Assert(err, IsNil)
rs, err := qctx.Execute(ctx, "show tables")
rs, err := Execute(ctx, qctx, "show tables")
c.Assert(err, IsNil)
req := rs[0].NewChunk()
err = rs[0].Next(ctx, req)
req := rs.NewChunk()
err = rs.Next(ctx, req)
c.Assert(err, IsNil)
cols := rs[0].Columns()
cols := rs.Columns()
c.Assert(err, IsNil)
c.Assert(len(cols), Equals, 1)
c.Assert(int(cols[0].ColumnLength), Equals, 26*tmysql.MaxBytesOfCharacter)
@ -815,7 +826,7 @@ func checkColNames(c *C, columns []*ColumnInfo, names ...string) {
func (ts *tidbTestSuite) TestFieldList(c *C) {
qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil)
c.Assert(err, IsNil)
_, err = qctx.Execute(context.Background(), "use test;")
_, err = Execute(context.Background(), qctx, "use test;")
c.Assert(err, IsNil)
ctx := context.Background()
@ -840,7 +851,7 @@ func (ts *tidbTestSuite) TestFieldList(c *C) {
c_json JSON,
c_year year
)`
_, err = qctx.Execute(ctx, testSQL)
_, err = Execute(ctx, qctx, testSQL)
c.Assert(err, IsNil)
colInfos, err := qctx.FieldList("t")
c.Assert(err, IsNil)
@ -878,15 +889,15 @@ func (ts *tidbTestSuite) TestFieldList(c *C) {
tooLongColumnAsName := "COALESCE(0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0)"
columnAsName := tooLongColumnAsName[:tmysql.MaxAliasIdentifierLen]
rs, err := qctx.Execute(ctx, "select "+tooLongColumnAsName)
rs, err := Execute(ctx, qctx, "select "+tooLongColumnAsName)
c.Assert(err, IsNil)
cols := rs[0].Columns()
cols := rs.Columns()
c.Assert(cols[0].OrgName, Equals, tooLongColumnAsName)
c.Assert(cols[0].Name, Equals, columnAsName)
rs, err = qctx.Execute(ctx, "select c_bit as '"+tooLongColumnAsName+"' from t")
rs, err = Execute(ctx, qctx, "select c_bit as '"+tooLongColumnAsName+"' from t")
c.Assert(err, IsNil)
cols = rs[0].Columns()
cols = rs.Columns()
c.Assert(cols[0].OrgName, Equals, "c_bit")
c.Assert(cols[0].Name, Equals, columnAsName)
}
@ -902,9 +913,9 @@ func (ts *tidbTestSuite) TestNullFlag(c *C) {
c.Assert(err, IsNil)
ctx := context.Background()
rs, err := qctx.Execute(ctx, "select 1")
rs, err := Execute(ctx, qctx, "select 1")
c.Assert(err, IsNil)
cols := rs[0].Columns()
cols := rs.Columns()
c.Assert(len(cols), Equals, 1)
expectFlag := uint16(tmysql.NotNullFlag | tmysql.BinaryFlag)
c.Assert(dumpFlag(cols[0].Type, cols[0].Flag), Equals, expectFlag)

View File

@ -95,13 +95,16 @@ var (
// Session context, it is consistent with the lifecycle of a client connection.
type Session interface {
sessionctx.Context
Status() uint16 // Flag of current status, such as autocommit.
LastInsertID() uint64 // LastInsertID is the last inserted auto_increment ID.
LastMessage() string // LastMessage is the info message that may be generated by last command
AffectedRows() uint64 // Affected rows by latest executed stmt.
Status() uint16 // Flag of current status, such as autocommit.
LastInsertID() uint64 // LastInsertID is the last inserted auto_increment ID.
LastMessage() string // LastMessage is the info message that may be generated by last command
AffectedRows() uint64 // Affected rows by latest executed stmt.
// Execute is deprecated, use ExecuteStmt() instead.
Execute(context.Context, string) ([]sqlexec.RecordSet, error) // Execute a sql statement.
ExecuteInternal(context.Context, string) ([]sqlexec.RecordSet, error) // Execute a internal sql statement.
String() string // String is used to debug.
ExecuteStmt(context.Context, ast.StmtNode) (sqlexec.RecordSet, error)
Parse(ctx context.Context, sql string) ([]ast.StmtNode, error)
String() string // String is used to debug.
CommitTxn(context.Context) error
RollbackTxn(context.Context)
// PrepareStmt executes prepare statement in binary protocol.
@ -1088,6 +1091,99 @@ func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec
return
}
// Parse parses a query string to raw ast.StmtNode.
func (s *session) Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) {
charsetInfo, collation := s.sessionVars.GetCharsetInfo()
parseStartTime := time.Now()
stmts, warns, err := s.ParseSQL(ctx, sql, charsetInfo, collation)
if err != nil {
s.rollbackOnError(ctx)
// Only print log message when this SQL is from the user.
// Mute the warning for internal SQLs.
if !s.sessionVars.InRestrictedSQL {
logutil.Logger(ctx).Warn("parse SQL failed", zap.Error(err), zap.String("SQL", sql))
}
return nil, util.SyntaxError(err)
}
durParse := time.Since(parseStartTime)
s.GetSessionVars().DurationParse = durParse
isInternal := s.isInternal()
if isInternal {
sessionExecuteParseDurationInternal.Observe(durParse.Seconds())
} else {
sessionExecuteParseDurationGeneral.Observe(durParse.Seconds())
}
for _, warn := range warns {
s.sessionVars.StmtCtx.AppendWarning(util.SyntaxWarn(warn))
}
return stmts, nil
}
func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) {
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
span1 := span.Tracer().StartSpan("session.ExecuteStmt", opentracing.ChildOf(span.Context()))
defer span1.Finish()
ctx = opentracing.ContextWithSpan(ctx, span1)
}
s.PrepareTxnCtx(ctx)
err := s.loadCommonGlobalVariablesIfNeeded()
if err != nil {
return nil, err
}
s.sessionVars.StartTime = time.Now()
// Some executions are done in compile stage, so we reset them before compile.
if err := executor.ResetContextOfStmt(s, stmtNode); err != nil {
return nil, err
}
// Transform abstract syntax tree to a physical plan(stored in executor.ExecStmt).
compiler := executor.Compiler{Ctx: s}
stmt, err := compiler.Compile(ctx, stmtNode)
if err != nil {
s.rollbackOnError(ctx)
// Only print log message when this SQL is from the user.
// Mute the warning for internal SQLs.
if !s.sessionVars.InRestrictedSQL {
logutil.Logger(ctx).Warn("compile SQL failed", zap.Error(err), zap.String("SQL", stmtNode.Text()))
}
return nil, err
}
durCompile := time.Since(s.sessionVars.StartTime)
s.GetSessionVars().DurationCompile = durCompile
if s.isInternal() {
sessionExecuteCompileDurationInternal.Observe(durCompile.Seconds())
} else {
sessionExecuteCompileDurationGeneral.Observe(durCompile.Seconds())
}
s.currentPlan = stmt.Plan
// Execute the physical plan.
logStmt(stmtNode, s.sessionVars)
startTime := time.Now()
recordSet, err := runStmt(ctx, s, stmt)
if err != nil {
if !kv.ErrKeyExists.Equal(err) {
logutil.Logger(ctx).Warn("run statement failed",
zap.Int64("schemaVersion", s.sessionVars.TxnCtx.SchemaVersion),
zap.Error(err),
zap.String("session", s.String()))
}
return nil, err
}
if s.isInternal() {
sessionExecuteRunDurationInternal.Observe(time.Since(startTime).Seconds())
} else {
sessionExecuteRunDurationGeneral.Observe(time.Since(startTime).Seconds())
}
return recordSet, nil
}
func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) {
s.PrepareTxnCtx(ctx)
connID := s.sessionVars.ConnectionID