diff --git a/server/conn.go b/server/conn.go index eeac16ea40..fc4a1e3204 100644 --- a/server/conn.go +++ b/server/conn.go @@ -54,6 +54,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/parser" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/auth" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" @@ -931,7 +932,11 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error { func (cc *clientConn) useDB(ctx context.Context, db string) (err error) { // if input is "use `SELECT`", mysql client just send "SELECT" // so we add `` around db. - _, err = cc.ctx.Execute(ctx, "use `"+db+"`") + stmts, err := cc.ctx.Parse(ctx, "use `"+db+"`") + if err != nil { + return err + } + _, err = cc.ctx.ExecuteStmt(ctx, stmts[0]) if err != nil { return err } @@ -1263,55 +1268,84 @@ func (cc *clientConn) handleIndexAdvise(ctx context.Context, indexAdviseInfo *ex // There is a special query `load data` that does not return result, which is handled differently. // Query `load stats` does not return result either. func (cc *clientConn) handleQuery(ctx context.Context, sql string) (err error) { - rss, err := cc.ctx.Execute(ctx, sql) + stmts, err := cc.ctx.Parse(ctx, sql) if err != nil { metrics.ExecuteErrorCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err)).Inc() return err } - status := atomic.LoadInt32(&cc.status) - if rss != nil && (status == connStatusShutdown || status == connStatusWaitShutdown) { - for _, rs := range rss { - terror.Call(rs.Close) + + for i, stmt := range stmts { + if err = cc.handleStmt(ctx, stmt, i == len(stmts)-1); err != nil { + break } - return executor.ErrQueryInterrupted } - if rss != nil { - if len(rss) == 1 { - err = cc.writeResultset(ctx, rss[0], false, 0, 0) - } else { - err = cc.writeMultiResultset(ctx, rss, false) - } - } else { - loadDataInfo := cc.ctx.Value(executor.LoadDataVarKey) - if loadDataInfo != nil { - defer cc.ctx.SetValue(executor.LoadDataVarKey, nil) - if err = cc.handleLoadData(ctx, loadDataInfo.(*executor.LoadDataInfo)); err != nil { - return err - } - } - loadStats := cc.ctx.Value(executor.LoadStatsVarKey) - if loadStats != nil { - defer cc.ctx.SetValue(executor.LoadStatsVarKey, nil) - if err = cc.handleLoadStats(ctx, loadStats.(*executor.LoadStatsInfo)); err != nil { - return err - } - } - - indexAdvise := cc.ctx.Value(executor.IndexAdviseVarKey) - if indexAdvise != nil { - defer cc.ctx.SetValue(executor.IndexAdviseVarKey, nil) - err = cc.handleIndexAdvise(ctx, indexAdvise.(*executor.IndexAdviseInfo)) - if err != nil { - return err - } - } - - err = cc.writeOK() + if err != nil { + metrics.ExecuteErrorCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err)).Inc() } return err } +func (cc *clientConn) handleStmt(ctx context.Context, stmt ast.StmtNode, lastStmt bool) error { + rs, err := cc.ctx.ExecuteStmt(ctx, stmt) + if rs != nil { + defer terror.Call(rs.Close) + } + if err != nil { + return err + } + + status := cc.ctx.Status() + if !lastStmt { + status |= mysql.ServerMoreResultsExists + } + + if rs != nil { + connStatus := atomic.LoadInt32(&cc.status) + if connStatus == connStatusShutdown || connStatus == connStatusWaitShutdown { + return executor.ErrQueryInterrupted + } + + err = cc.writeResultset(ctx, rs, false, status, 0) + if err != nil { + return err + } + } else { + err = cc.handleQuerySpecial(ctx, status) + if err != nil { + return err + } + } + return nil +} + +func (cc *clientConn) handleQuerySpecial(ctx context.Context, status uint16) error { + loadDataInfo := cc.ctx.Value(executor.LoadDataVarKey) + if loadDataInfo != nil { + defer cc.ctx.SetValue(executor.LoadDataVarKey, nil) + if err := cc.handleLoadData(ctx, loadDataInfo.(*executor.LoadDataInfo)); err != nil { + return err + } + } + + loadStats := cc.ctx.Value(executor.LoadStatsVarKey) + if loadStats != nil { + defer cc.ctx.SetValue(executor.LoadStatsVarKey, nil) + if err := cc.handleLoadStats(ctx, loadStats.(*executor.LoadStatsInfo)); err != nil { + return err + } + } + + indexAdvise := cc.ctx.Value(executor.IndexAdviseVarKey) + if indexAdvise != nil { + defer cc.ctx.SetValue(executor.IndexAdviseVarKey, nil) + if err := cc.handleIndexAdvise(ctx, indexAdvise.(*executor.IndexAdviseInfo)); err != nil { + return err + } + } + return cc.writeOkWith(cc.ctx.LastMessage(), cc.ctx.AffectedRows(), cc.ctx.LastInsertID(), status, cc.ctx.WarningCount()) +} + // handleFieldList returns the field list for a table. // The sql string is composed of a table name and a terminating character \x00. func (cc *clientConn) handleFieldList(sql string) (err error) { @@ -1344,13 +1378,9 @@ func (cc *clientConn) handleFieldList(sql string) (err error) { // If binary is true, the data would be encoded in BINARY format. // serverStatus, a flag bit represents server information. // fetchSize, the desired number of rows to be fetched each time when client uses cursor. -// resultsets, it's used to support the MULTI_RESULTS capability in mysql protocol. func (cc *clientConn) writeResultset(ctx context.Context, rs ResultSet, binary bool, serverStatus uint16, fetchSize int) (runErr error) { defer func() { // close ResultSet when cursor doesn't exist - if !mysql.HasCursorExistsFlag(serverStatus) { - terror.Call(rs.Close) - } r := recover() if r == nil { return diff --git a/server/conn_stmt.go b/server/conn_stmt.go index dfd1d4331e..e1aea78688 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -44,6 +44,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" + "github.com/pingcap/parser/terror" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" @@ -204,6 +205,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e // explicitly flush columnInfo to client. return cc.flush() } + defer terror.Call(rs.Close) err = cc.writeResultset(ctx, rs, true, 0, 0) if err != nil { return errors.Annotate(err, cc.preparedStmt2String(stmtID)) diff --git a/server/conn_test.go b/server/conn_test.go index 4997020c7b..48d208669f 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -23,7 +23,6 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/failpoint" - "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/domain" @@ -33,7 +32,6 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/util/arena" - "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/testleak" ) @@ -440,41 +438,27 @@ func (ts *ConnTestSuite) TestConnExecutionTimeout(c *C) { type mockTiDBCtx struct { TiDBContext - rs []ResultSet err error } -func (c *mockTiDBCtx) Execute(ctx context.Context, sql string) ([]ResultSet, error) { - return c.rs, c.err -} - -func (c *mockTiDBCtx) ExecuteInternal(ctx context.Context, sql string) ([]ResultSet, error) { - return c.rs, c.err -} - func (c *mockTiDBCtx) GetSessionVars() *variable.SessionVars { return &variable.SessionVars{} } -type mockRecordSet struct{} - -func (m mockRecordSet) Fields() []*ast.ResultField { return nil } -func (m mockRecordSet) Next(ctx context.Context, req *chunk.Chunk) error { return nil } -func (m mockRecordSet) NewChunk() *chunk.Chunk { return nil } -func (m mockRecordSet) Close() error { return nil } - func (ts *ConnTestSuite) TestShutDown(c *C) { cc := &clientConn{} - - rs := &tidbResultSet{recordSet: mockRecordSet{}} + se, err := session.CreateSession4Test(ts.store) + c.Assert(err, IsNil) // mock delay response - cc.ctx = &mockTiDBCtx{rs: []ResultSet{rs}, err: nil} + cc.ctx = &mockTiDBCtx{ + TiDBContext: TiDBContext{session: se}, + err: nil, + } // set killed flag cc.status = connStatusShutdown // assert ErrQueryInterrupted - err := cc.handleQuery(context.Background(), "dummy") + err = cc.handleQuery(context.Background(), "select 1") c.Assert(err, Equals, executor.ErrQueryInterrupted) - c.Assert(rs.closed, Equals, int32(1)) } func (ts *ConnTestSuite) TestShutdownOrNotify(c *C) { diff --git a/server/driver.go b/server/driver.go index f336db8f13..389f9f2700 100644 --- a/server/driver.go +++ b/server/driver.go @@ -19,6 +19,7 @@ import ( "fmt" "time" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/auth" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" @@ -66,11 +67,11 @@ type QueryCtx interface { // CurrentDB returns current DB. CurrentDB() string - // Execute executes a SQL statement. - Execute(ctx context.Context, sql string) ([]ResultSet, error) + // ExecuteStmt executes a SQL statement. + ExecuteStmt(context.Context, ast.StmtNode) (ResultSet, error) - // ExecuteInternal executes a internal SQL statement. - ExecuteInternal(ctx context.Context, sql string) ([]ResultSet, error) + // Parse parses a SQL to statement node. + Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) // SetClientCapability sets client capability flags SetClientCapability(uint32) diff --git a/server/driver_tidb.go b/server/driver_tidb.go index 8413b850db..9857217761 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -243,40 +243,23 @@ func (tc *TiDBContext) WarningCount() uint16 { return tc.session.GetSessionVars().StmtCtx.WarningCount() } -// Execute implements QueryCtx Execute method. -func (tc *TiDBContext) Execute(ctx context.Context, sql string) (rs []ResultSet, err error) { - rsList, err := tc.session.Execute(ctx, sql) +// ExecuteStmt implements QueryCtx interface. +func (tc *TiDBContext) ExecuteStmt(ctx context.Context, stmt ast.StmtNode) (ResultSet, error) { + rs, err := tc.session.ExecuteStmt(ctx, stmt) if err != nil { - return + return nil, err } - if len(rsList) == 0 { // result ok - return + if rs == nil { + return nil, nil } - rs = make([]ResultSet, len(rsList)) - for i := 0; i < len(rsList); i++ { - rs[i] = &tidbResultSet{ - recordSet: rsList[i], - } - } - return + return &tidbResultSet{ + recordSet: rs, + }, nil } -// ExecuteInternal implements QueryCtx ExecuteInternal method. -func (tc *TiDBContext) ExecuteInternal(ctx context.Context, sql string) (rs []ResultSet, err error) { - rsList, err := tc.session.ExecuteInternal(ctx, sql) - if err != nil { - return - } - if len(rsList) == 0 { // result ok - return - } - rs = make([]ResultSet, len(rsList)) - for i := 0; i < len(rsList); i++ { - rs[i] = &tidbResultSet{ - recordSet: rsList[i], - } - } - return +// Parse implements QueryCtx interface. +func (tc *TiDBContext) Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) { + return tc.session.Parse(ctx, sql) } // SetSessionManager implements the QueryCtx interface. diff --git a/server/tidb_test.go b/server/tidb_test.go index bd83e12327..0b35f5821d 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -729,7 +729,7 @@ func (ts *tidbTestSuite) TestCreateTableFlen(c *C) { // issue #4540 qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) c.Assert(err, IsNil) - _, err = qctx.Execute(context.Background(), "use test;") + _, err = Execute(context.Background(), qctx, "use test;") c.Assert(err, IsNil) ctx := context.Background() @@ -762,44 +762,55 @@ func (ts *tidbTestSuite) TestCreateTableFlen(c *C) { "`z` decimal(20, 4)," + "PRIMARY KEY (`a`)" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin" - _, err = qctx.Execute(ctx, testSQL) + _, err = Execute(ctx, qctx, testSQL) c.Assert(err, IsNil) - rs, err := qctx.Execute(ctx, "show create table t1") + rs, err := Execute(ctx, qctx, "show create table t1") c.Assert(err, IsNil) - req := rs[0].NewChunk() - err = rs[0].Next(ctx, req) + req := rs.NewChunk() + err = rs.Next(ctx, req) c.Assert(err, IsNil) - cols := rs[0].Columns() + cols := rs.Columns() c.Assert(err, IsNil) c.Assert(len(cols), Equals, 2) c.Assert(int(cols[0].ColumnLength), Equals, 5*tmysql.MaxBytesOfCharacter) c.Assert(int(cols[1].ColumnLength), Equals, len(req.GetRow(0).GetString(1))*tmysql.MaxBytesOfCharacter) // for issue#5246 - rs, err = qctx.Execute(ctx, "select y, z from t1") + rs, err = Execute(ctx, qctx, "select y, z from t1") c.Assert(err, IsNil) - cols = rs[0].Columns() + cols = rs.Columns() c.Assert(len(cols), Equals, 2) c.Assert(int(cols[0].ColumnLength), Equals, 21) c.Assert(int(cols[1].ColumnLength), Equals, 22) } +func Execute(ctx context.Context, qc QueryCtx, sql string) (ResultSet, error) { + stmts, err := qc.Parse(ctx, sql) + if err != nil { + return nil, err + } + if len(stmts) != 1 { + panic("wrong input for Execute: " + sql) + } + return qc.ExecuteStmt(ctx, stmts[0]) +} + func (ts *tidbTestSuite) TestShowTablesFlen(c *C) { qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) c.Assert(err, IsNil) - _, err = qctx.Execute(context.Background(), "use test;") + ctx := context.Background() + _, err = Execute(ctx, qctx, "use test;") c.Assert(err, IsNil) - ctx := context.Background() testSQL := "create table abcdefghijklmnopqrstuvwxyz (i int)" - _, err = qctx.Execute(ctx, testSQL) + _, err = Execute(ctx, qctx, testSQL) c.Assert(err, IsNil) - rs, err := qctx.Execute(ctx, "show tables") + rs, err := Execute(ctx, qctx, "show tables") c.Assert(err, IsNil) - req := rs[0].NewChunk() - err = rs[0].Next(ctx, req) + req := rs.NewChunk() + err = rs.Next(ctx, req) c.Assert(err, IsNil) - cols := rs[0].Columns() + cols := rs.Columns() c.Assert(err, IsNil) c.Assert(len(cols), Equals, 1) c.Assert(int(cols[0].ColumnLength), Equals, 26*tmysql.MaxBytesOfCharacter) @@ -815,7 +826,7 @@ func checkColNames(c *C, columns []*ColumnInfo, names ...string) { func (ts *tidbTestSuite) TestFieldList(c *C) { qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil) c.Assert(err, IsNil) - _, err = qctx.Execute(context.Background(), "use test;") + _, err = Execute(context.Background(), qctx, "use test;") c.Assert(err, IsNil) ctx := context.Background() @@ -840,7 +851,7 @@ func (ts *tidbTestSuite) TestFieldList(c *C) { c_json JSON, c_year year )` - _, err = qctx.Execute(ctx, testSQL) + _, err = Execute(ctx, qctx, testSQL) c.Assert(err, IsNil) colInfos, err := qctx.FieldList("t") c.Assert(err, IsNil) @@ -878,15 +889,15 @@ func (ts *tidbTestSuite) TestFieldList(c *C) { tooLongColumnAsName := "COALESCE(0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0)" columnAsName := tooLongColumnAsName[:tmysql.MaxAliasIdentifierLen] - rs, err := qctx.Execute(ctx, "select "+tooLongColumnAsName) + rs, err := Execute(ctx, qctx, "select "+tooLongColumnAsName) c.Assert(err, IsNil) - cols := rs[0].Columns() + cols := rs.Columns() c.Assert(cols[0].OrgName, Equals, tooLongColumnAsName) c.Assert(cols[0].Name, Equals, columnAsName) - rs, err = qctx.Execute(ctx, "select c_bit as '"+tooLongColumnAsName+"' from t") + rs, err = Execute(ctx, qctx, "select c_bit as '"+tooLongColumnAsName+"' from t") c.Assert(err, IsNil) - cols = rs[0].Columns() + cols = rs.Columns() c.Assert(cols[0].OrgName, Equals, "c_bit") c.Assert(cols[0].Name, Equals, columnAsName) } @@ -902,9 +913,9 @@ func (ts *tidbTestSuite) TestNullFlag(c *C) { c.Assert(err, IsNil) ctx := context.Background() - rs, err := qctx.Execute(ctx, "select 1") + rs, err := Execute(ctx, qctx, "select 1") c.Assert(err, IsNil) - cols := rs[0].Columns() + cols := rs.Columns() c.Assert(len(cols), Equals, 1) expectFlag := uint16(tmysql.NotNullFlag | tmysql.BinaryFlag) c.Assert(dumpFlag(cols[0].Type, cols[0].Flag), Equals, expectFlag) diff --git a/session/session.go b/session/session.go index 31b8aa2568..7b7db75e1d 100644 --- a/session/session.go +++ b/session/session.go @@ -95,13 +95,16 @@ var ( // Session context, it is consistent with the lifecycle of a client connection. type Session interface { sessionctx.Context - Status() uint16 // Flag of current status, such as autocommit. - LastInsertID() uint64 // LastInsertID is the last inserted auto_increment ID. - LastMessage() string // LastMessage is the info message that may be generated by last command - AffectedRows() uint64 // Affected rows by latest executed stmt. + Status() uint16 // Flag of current status, such as autocommit. + LastInsertID() uint64 // LastInsertID is the last inserted auto_increment ID. + LastMessage() string // LastMessage is the info message that may be generated by last command + AffectedRows() uint64 // Affected rows by latest executed stmt. + // Execute is deprecated, use ExecuteStmt() instead. Execute(context.Context, string) ([]sqlexec.RecordSet, error) // Execute a sql statement. ExecuteInternal(context.Context, string) ([]sqlexec.RecordSet, error) // Execute a internal sql statement. - String() string // String is used to debug. + ExecuteStmt(context.Context, ast.StmtNode) (sqlexec.RecordSet, error) + Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) + String() string // String is used to debug. CommitTxn(context.Context) error RollbackTxn(context.Context) // PrepareStmt executes prepare statement in binary protocol. @@ -1088,6 +1091,99 @@ func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec return } +// Parse parses a query string to raw ast.StmtNode. +func (s *session) Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) { + charsetInfo, collation := s.sessionVars.GetCharsetInfo() + parseStartTime := time.Now() + stmts, warns, err := s.ParseSQL(ctx, sql, charsetInfo, collation) + if err != nil { + s.rollbackOnError(ctx) + + // Only print log message when this SQL is from the user. + // Mute the warning for internal SQLs. + if !s.sessionVars.InRestrictedSQL { + logutil.Logger(ctx).Warn("parse SQL failed", zap.Error(err), zap.String("SQL", sql)) + } + return nil, util.SyntaxError(err) + } + + durParse := time.Since(parseStartTime) + s.GetSessionVars().DurationParse = durParse + isInternal := s.isInternal() + if isInternal { + sessionExecuteParseDurationInternal.Observe(durParse.Seconds()) + } else { + sessionExecuteParseDurationGeneral.Observe(durParse.Seconds()) + } + for _, warn := range warns { + s.sessionVars.StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + } + return stmts, nil +} + +func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("session.ExecuteStmt", opentracing.ChildOf(span.Context())) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + s.PrepareTxnCtx(ctx) + err := s.loadCommonGlobalVariablesIfNeeded() + if err != nil { + return nil, err + } + + s.sessionVars.StartTime = time.Now() + + // Some executions are done in compile stage, so we reset them before compile. + if err := executor.ResetContextOfStmt(s, stmtNode); err != nil { + return nil, err + } + + // Transform abstract syntax tree to a physical plan(stored in executor.ExecStmt). + compiler := executor.Compiler{Ctx: s} + stmt, err := compiler.Compile(ctx, stmtNode) + if err != nil { + s.rollbackOnError(ctx) + + // Only print log message when this SQL is from the user. + // Mute the warning for internal SQLs. + if !s.sessionVars.InRestrictedSQL { + logutil.Logger(ctx).Warn("compile SQL failed", zap.Error(err), zap.String("SQL", stmtNode.Text())) + } + return nil, err + } + durCompile := time.Since(s.sessionVars.StartTime) + s.GetSessionVars().DurationCompile = durCompile + if s.isInternal() { + sessionExecuteCompileDurationInternal.Observe(durCompile.Seconds()) + } else { + sessionExecuteCompileDurationGeneral.Observe(durCompile.Seconds()) + } + s.currentPlan = stmt.Plan + + // Execute the physical plan. + logStmt(stmtNode, s.sessionVars) + startTime := time.Now() + recordSet, err := runStmt(ctx, s, stmt) + if err != nil { + if !kv.ErrKeyExists.Equal(err) { + logutil.Logger(ctx).Warn("run statement failed", + zap.Int64("schemaVersion", s.sessionVars.TxnCtx.SchemaVersion), + zap.Error(err), + zap.String("session", s.String())) + } + return nil, err + } + if s.isInternal() { + sessionExecuteRunDurationInternal.Observe(time.Since(startTime).Seconds()) + } else { + sessionExecuteRunDurationGeneral.Observe(time.Since(startTime).Seconds()) + } + return recordSet, nil +} + func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) { s.PrepareTxnCtx(ctx) connID := s.sessionVars.ConnectionID