diff --git a/session.go b/session.go index 8662ecd6a9..908a994233 100644 --- a/session.go +++ b/session.go @@ -137,6 +137,7 @@ func (s *session) FinishTxn(rollback bool) error { } defer func() { s.txn = nil + variable.GetSessionVars(s).SetStatusInTrans(false) }() if rollback { @@ -345,6 +346,7 @@ func (s *session) GetTxn(forceNew bool) (kv.Transaction, error) { } if forceNew { err = s.txn.Commit() + variable.GetSessionVars(s).SetStatusInTrans(false) if err != nil { return nil, err } diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 9344c9b49a..fdfc9c7916 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -38,9 +38,6 @@ type SessionVars struct { // Client Capability ClientCapability uint32 // Client capability - // Disable autocommit - DisableAutocommit bool - // Found rows FoundRows uint64 } @@ -102,19 +99,29 @@ func (s *SessionVars) SetStatus(status uint16) { s.Status = status } +// SetStatusInTrans sets the status flags about ServerStatusInTrans. +func (s *SessionVars) SetStatusInTrans(isInTrans bool) { + if isInTrans { + s.Status |= mysql.ServerStatusInTrans + return + } + s.Status &= (^mysql.ServerStatusInTrans) +} + // GetNextPreparedStmtID generates and return the next session scope prepared statement id func (s *SessionVars) GetNextPreparedStmtID() uint32 { s.preparedStmtID++ return s.preparedStmtID } -// IsAutocommit checks if it is in autocommit enviroment -func IsAutocommit(ctx context.Context) bool { +// ShouldAutocommit checks if it is in autocommit enviroment +func ShouldAutocommit(ctx context.Context) bool { // With START TRANSACTION, autocommit remains disabled until you end // the transaction with COMMIT or ROLLBACK. if GetSessionVars(ctx).Status&mysql.ServerStatusAutocommit > 0 && - !GetSessionVars(ctx).DisableAutocommit { + GetSessionVars(ctx).Status&mysql.ServerStatusInTrans == 0 { return true } + return false } diff --git a/stmt/stmts/select.go b/stmt/stmts/select.go index 6782f91862..1c930365b1 100644 --- a/stmt/stmts/select.go +++ b/stmt/stmts/select.go @@ -147,7 +147,7 @@ func (s *SelectStmt) Plan(ctx context.Context) (plan.Plan, error) { } } lock := s.Lock - if variable.IsAutocommit(ctx) { + if variable.ShouldAutocommit(ctx) { // Locking of rows for update using SELECT FOR UPDATE only applies when autocommit // is disabled (either by beginning transaction with START TRANSACTION or by setting // autocommit to 0. If autocommit is enabled, the rows matching the specification are not locked. diff --git a/stmt/stmts/transaction.go b/stmt/stmts/transaction.go index 2b5c8e7a1c..4780a5ad9f 100644 --- a/stmt/stmts/transaction.go +++ b/stmt/stmts/transaction.go @@ -63,7 +63,7 @@ func (s *BeginStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { // With START TRANSACTION, autocommit remains disabled until you end // the transaction with COMMIT or ROLLBACK. The autocommit mode then // reverts to its previous state. - variable.GetSessionVars(ctx).DisableAutocommit = true + variable.GetSessionVars(ctx).SetStatusInTrans(true) return } @@ -96,7 +96,7 @@ func (s *CommitStmt) SetText(text string) { // Exec implements the stmt.Statement Exec interface. func (s *CommitStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { err = ctx.FinishTxn(false) - variable.GetSessionVars(ctx).DisableAutocommit = false + variable.GetSessionVars(ctx).SetStatusInTrans(false) return } @@ -129,6 +129,6 @@ func (s *RollbackStmt) SetText(text string) { // Exec implements the stmt.Statement Exec interface. func (s *RollbackStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { err = ctx.FinishTxn(true) - variable.GetSessionVars(ctx).DisableAutocommit = false + variable.GetSessionVars(ctx).SetStatusInTrans(false) return } diff --git a/tidb.go b/tidb.go index b6542b9c1a..c55dc9df31 100644 --- a/tidb.go +++ b/tidb.go @@ -161,7 +161,7 @@ func runStmt(ctx context.Context, s stmt.Statement, args ...interface{}) (rset.R stmt.ClearExecArgs(ctx) } // MySQL DDL should be auto-commit - if err == nil && (s.IsDDL() || variable.IsAutocommit(ctx)) { + if err == nil && (s.IsDDL() || variable.ShouldAutocommit(ctx)) { err = ctx.FinishTxn(false) } return rs, errors.Trace(err) diff --git a/tidb_test.go b/tidb_test.go index dcb25e18ae..a24283feea 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/tidb/kv" mysql "github.com/pingcap/tidb/mysqldef" "github.com/pingcap/tidb/rset" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/errors2" ) @@ -462,6 +463,41 @@ func (s *testSessionSuite) TestAutoincrementID(c *C) { mustExecSQL(c, se, s.dropDBSQL) } +func checkInTrans(c *C, se Session, stmt string, expect uint16) { + mustExecSQL(c, se, stmt) + if expect == 0 { + c.Assert(se.(*session).txn, IsNil) + } else { + c.Assert(se.(*session).txn, NotNil) + } + ret := variable.GetSessionVars(se.(*session)).Status & mysql.ServerStatusInTrans + c.Assert(ret, Equals, expect) +} + +// See: https://dev.mysql.com/doc/internals/en/status-flags.html +func (s *testSessionSuite) TestInTrans(c *C) { + store := newStore(c, s.dbName) + se := newSession(c, store, s.dbName) + checkInTrans(c, se, "drop table if exists t;", 0) + checkInTrans(c, se, "create table t (id BIGINT PRIMARY KEY AUTO_INCREMENT NOT NULL)", 0) + checkInTrans(c, se, "insert t values ()", 0) + checkInTrans(c, se, "begin", 1) + checkInTrans(c, se, "insert t values ()", 1) + checkInTrans(c, se, "drop table if exists t;", 0) + checkInTrans(c, se, "create table t (id BIGINT PRIMARY KEY AUTO_INCREMENT NOT NULL)", 0) + checkInTrans(c, se, "insert t values ()", 0) + checkInTrans(c, se, "commit", 0) + checkInTrans(c, se, "insert t values ()", 0) + + checkInTrans(c, se, "drop table if exists t;", 0) + checkInTrans(c, se, "create table t (id BIGINT PRIMARY KEY AUTO_INCREMENT NOT NULL)", 0) + checkInTrans(c, se, "begin", 1) + checkInTrans(c, se, "insert t values ()", 1) + checkInTrans(c, se, "rollback", 0) + + mustExecSQL(c, se, s.dropDBSQL) +} + // See: http://dev.mysql.com/doc/refman/5.7/en/commit.html func (s *testSessionSuite) TestRowLock(c *C) { store := newStore(c, s.dbName)