From ad13ca15ca6f73ee24fc9f7efe8a771087bf047f Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Wed, 29 Mar 2017 16:09:47 +0800 Subject: [PATCH] session: reset affected rows in retry. (#2949) --- session.go | 1 + session_test.go | 19 +++++++++++++++++++ sessionctx/variable/session.go | 9 +++++++++ sessionctx/variable/session_test.go | 5 +++++ 4 files changed, 34 insertions(+) diff --git a/session.go b/session.go index 3e75211bbb..b6a128af63 100644 --- a/session.go +++ b/session.go @@ -353,6 +353,7 @@ func (s *session) retry(maxCnt int) error { log.Warnf("[%d] Retry [%d] query [%d]", connID, retryCnt, i) } s.sessionVars.StmtCtx = sr.stmtCtx + s.sessionVars.StmtCtx.ResetForRetry() _, err = st.Exec(s) if err != nil { break diff --git a/session_test.go b/session_test.go index 23507ba267..727c0ca966 100644 --- a/session_test.go +++ b/session_test.go @@ -2692,3 +2692,22 @@ func (s *testSessionSuite) TestRetryCleanTxn(c *C) { c.Assert(se.Txn(), IsNil) c.Assert(se.sessionVars.InTxn(), IsFalse) } + +func (s *testSessionSuite) TestRetryResetStmtCtx(c *C) { + defer testleak.AfterTest(c)() + dbName := "test_retry_reset_stmtctx" + se := newSession(c, s.store, dbName).(*session) + se.Execute("create table retrytxn (a int unique, b int)") + _, err := se.Execute("insert retrytxn values (1, 1)") + c.Assert(err, IsNil) + se.Execute("begin") + se.Execute("update retrytxn set b = b + 1 where a = 1") + + // Make retryable error. + se2 := newSession(c, s.store, dbName) + se2.Execute("update retrytxn set b = b + 1 where a = 1") + + err = se.CommitTxn() + c.Assert(err, IsNil) + c.Assert(se.AffectedRows(), Equals, uint64(1)) +} diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 84a9db382e..cf69865350 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -368,3 +368,12 @@ func (sc *StatementContext) HandleTruncate(err error) error { } return err } + +// ResetForRetry resets the changed states during execution. +func (sc *StatementContext) ResetForRetry() { + sc.mu.Lock() + sc.mu.affectedRows = 0 + sc.mu.foundRows = 0 + sc.mu.warnings = nil + sc.mu.Unlock() +} diff --git a/sessionctx/variable/session_test.go b/sessionctx/variable/session_test.go index b33dea7bcd..340c175978 100644 --- a/sessionctx/variable/session_test.go +++ b/sessionctx/variable/session_test.go @@ -44,4 +44,9 @@ func (*testSessionSuite) TestSession(c *C) { // For last insert id ctx.GetSessionVars().SetLastInsertID(1) c.Assert(ctx.GetSessionVars().LastInsertID, Equals, uint64(1)) + + ss.ResetForRetry() + c.Assert(ss.AffectedRows(), Equals, uint64(0)) + c.Assert(ss.FoundRows(), Equals, uint64(0)) + c.Assert(ss.WarningCount(), Equals, uint16(0)) }