diff --git a/driver.go b/driver.go index b41c784a9b..5c74391b65 100644 --- a/driver.go +++ b/driver.go @@ -29,11 +29,13 @@ import ( "sync" "github.com/juju/errors" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" mysql "github.com/pingcap/tidb/mysqldef" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx" qerror "github.com/pingcap/tidb/util/errors" + "github.com/pingcap/tidb/util/errors2" "github.com/pingcap/tidb/util/types" ) @@ -239,8 +241,13 @@ func (c *driverConn) Commit() error { if c.s == nil { return qerror.ErrCommitNotInTransaction } + _, err := c.s.Execute(txCommitSQL) - if _, err := c.s.Execute(txCommitSQL); err != nil { + if errors2.ErrorEqual(err, kv.ErrConditionNotMatch) { + return c.s.Retry() + } + + if err != nil { return err } diff --git a/kv/index_iter.go b/kv/index_iter.go index 4e057c6125..78cac4cf77 100644 --- a/kv/index_iter.go +++ b/kv/index_iter.go @@ -23,8 +23,6 @@ import ( "github.com/juju/errors" ) -// Cockroach sql index implementation - var ( _ Index = (*kvIndex)(nil) _ IndexIterator = (*IndexIter)(nil) @@ -138,7 +136,7 @@ func (c *kvIndex) genIndexKey(indexedValues []interface{}, h int64) ([]byte, err } // Create creates a new entry in the kvIndex data. -// If the index is unique and there already exists an entry with the same key, Create will return ErrConditionNotMatch +// If the index is unique and there already exists an entry with the same key, Create will return ErrKeyExists func (c *kvIndex) Create(txn Transaction, indexedValues []interface{}, h int64) error { keyBuf, err := c.genIndexKey(indexedValues, h) if err != nil { @@ -157,7 +155,7 @@ func (c *kvIndex) Create(txn Transaction, indexedValues []interface{}, h int64) return errors.Trace(err) } - return errors.Trace(ErrConditionNotMatch) + return errors.Trace(ErrKeyExists) } // Delete removes the entry for handle h and indexdValues from KV index. diff --git a/kv/iter.go b/kv/iter.go index b914ff44cf..09ec2197f7 100644 --- a/kv/iter.go +++ b/kv/iter.go @@ -24,6 +24,8 @@ var ( ErrClosed = errors.New("Error: Transaction already closed") // ErrNotExist is used when try to get an entry with an unexist key from KV store. ErrNotExist = errors.New("Error: key not exist") + // ErrKeyExists is used when try to put an entry to KV store. + ErrKeyExists = errors.New("Error: key already exist") // ErrConditionNotMatch is used when condition is not met. ErrConditionNotMatch = errors.New("Error: Condition not match") // ErrLockConflict is used when try to lock an already locked key. diff --git a/plan/plans/lock.go b/plan/plans/lock.go index 703637a8a4..6762e870f5 100644 --- a/plan/plans/lock.go +++ b/plan/plans/lock.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/parser/coldef" "github.com/pingcap/tidb/plan" + "github.com/pingcap/tidb/sessionctx/forupdate" "github.com/pingcap/tidb/util/format" ) @@ -48,6 +49,7 @@ func (r *SelectLockPlan) Do(ctx context.Context, f plan.RowIterFunc) error { } } if rowKeys != nil && r.Lock == coldef.SelectLockForUpdate { + forupdate.SetForUpdate(ctx) txn, err := ctx.GetTxn(false) if err != nil { return false, errors.Trace(err) diff --git a/session.go b/session.go index b22546357f..3a1f124268 100644 --- a/session.go +++ b/session.go @@ -31,7 +31,11 @@ import ( "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/db" + "github.com/pingcap/tidb/sessionctx/forupdate" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/stmt" + "github.com/pingcap/tidb/stmt/stmts" + "github.com/pingcap/tidb/util/errors2" "github.com/pingcap/tidb/util/types" ) @@ -51,6 +55,7 @@ type Session interface { DropPreparedStmt(stmtID uint32) error SetClientCapability(uint32) // Set client capability flags Close() error + Retry() error } var ( @@ -58,14 +63,47 @@ var ( sessionID int64 ) +type stmtRecord struct { + stmtID uint32 + st stmt.Statement + params []interface{} +} + +type stmtHistory struct { + history []*stmtRecord +} + +func (h *stmtHistory) add(stmtID uint32, st stmt.Statement, params ...interface{}) { + s := &stmtRecord{ + stmtID: stmtID, + st: st, + params: append(([]interface{})(nil), params...), + } + + h.history = append(h.history, s) +} + +func (h *stmtHistory) reset() { + if len(h.history) > 0 { + h.history = h.history[:0] + } +} + +func (h *stmtHistory) clone() *stmtHistory { + nh := *h + nh.history = make([]*stmtRecord, len(h.history)) + copy(nh.history, h.history) + return &nh +} + type session struct { txn kv.Transaction // Current transaction userName string args []interface{} // Statment execution args, this should be cleaned up after exec - - values map[fmt.Stringer]interface{} - store kv.Storage - sid int64 + values map[fmt.Stringer]interface{} + store kv.Storage + sid int64 + history stmtHistory } func (s *session) Status() uint16 { @@ -84,6 +122,11 @@ func (s *session) SetUsername(name string) { s.userName = name } +func (s *session) resetHistory() { + s.ClearValue(forupdate.ForUpdateKey) + s.history.reset() +} + func (s *session) SetClientCapability(capability uint32) { variable.GetSessionVars(s).ClientCapability = capability } @@ -103,10 +146,12 @@ func (s *session) FinishTxn(rollback bool) error { err := s.txn.Commit() if err != nil { - log.Errorf("txn:%s, %v", s.txn, err) + log.Warnf("txn:%s, %v", s.txn, err) + return errors.Trace(err) } - return errors.Trace(err) + s.resetHistory() + return nil } func (s *session) String() string { @@ -126,8 +171,66 @@ func (s *session) String() string { return string(b) } +func needRetry(st stmt.Statement) bool { + switch st.(type) { + case *stmts.PreparedStmt, *stmts.ShowStmt, *stmts.DoStmt: + return false + default: + return true + } +} + +func isPreparedStmt(st stmt.Statement) bool { + switch st.(type) { + case *stmts.PreparedStmt: + return true + default: + return false + } +} + +func (s *session) Retry() error { + nh := s.history.clone() + defer func() { + s.history.history = nh.history + }() + + if forUpdate := s.Value(forupdate.ForUpdateKey); forUpdate != nil { + return errors.Errorf("can not retry select for update statement") + } + + var err error + for { + s.resetHistory() + s.FinishTxn(true) + success := true + for _, sr := range nh.history { + st := sr.st + // Skip prepare statement + if !needRetry(st) { + continue + } + log.Warnf("Retry %s", st.OriginText()) + _, err = runStmt(s, st) + if err != nil { + if errors2.ErrorEqual(err, kv.ErrConditionNotMatch) { + success = false + break + } + log.Warnf("session:%v, err:%v", s, err) + return errors.Trace(err) + } + } + if success { + return nil + } + } + + return nil +} + func (s *session) Execute(sql string) ([]rset.Recordset, error) { - stmts, err := Compile(sql) + statements, err := Compile(sql) if err != nil { log.Errorf("Syntax error: %s", sql) log.Errorf("Error occurs at %s.", err) @@ -136,13 +239,21 @@ func (s *session) Execute(sql string) ([]rset.Recordset, error) { var rs []rset.Recordset - for _, si := range stmts { - r, err := runStmt(s, si) + for _, st := range statements { + r, err := runStmt(s, st) if err != nil { log.Warnf("session:%v, err:%v", s, err) return nil, errors.Trace(err) } + // Record executed query + if isPreparedStmt(st) { + ps := st.(*stmts.PreparedStmt) + s.history.add(ps.ID, st) + } else { + s.history.add(0, st) + } + if r != nil { rs = append(rs, r) } @@ -183,12 +294,10 @@ func (s *session) ExecutePreparedStmt(stmtID uint32, args ...interface{}) (rset. if err != nil { return nil, err } - //convert args to param - rs, err := executePreparedStmt(s, stmtID, args...) - if err != nil { - return nil, err - } - return rs, nil + + st := &stmts.ExecuteStmt{ID: stmtID} + s.history.add(stmtID, st, args...) + return runStmt(s, st, args...) } func (s *session) DropPreparedStmt(stmtID uint32) error { @@ -201,6 +310,7 @@ func (s *session) DropPreparedStmt(stmtID uint32) error { func (s *session) GetTxn(forceNew bool) (kv.Transaction, error) { var err error if s.txn == nil { + s.resetHistory() s.txn, err = s.store.Begin() if err != nil { return nil, err @@ -214,6 +324,7 @@ func (s *session) GetTxn(forceNew bool) (kv.Transaction, error) { if err != nil { return nil, err } + s.resetHistory() s.txn, err = s.store.Begin() if err != nil { return nil, err diff --git a/sessionctx/forupdate/for_update_ctx.go b/sessionctx/forupdate/for_update_ctx.go new file mode 100644 index 0000000000..f413bd1c50 --- /dev/null +++ b/sessionctx/forupdate/for_update_ctx.go @@ -0,0 +1,33 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package forupdate record information for "select ... for update" statement +package forupdate + +import "github.com/pingcap/tidb/context" + +// A dummy type to avoid naming collision in context. +type forupdateKeyType int + +// String defines a Stringer function for debugging and pretty printing. +func (k forupdateKeyType) String() string { + return "for update" +} + +// ForUpdateKey is used to retrive "select for update" statement information +const ForUpdateKey forupdateKeyType = 0 + +// SetForUpdate set "select for update" flag. +func SetForUpdate(ctx context.Context) { + ctx.SetValue(ForUpdateKey, true) +} diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index f75f71e7a7..ceb21d3701 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -125,7 +125,6 @@ func (s *InsertIntoStmt) execSelect(t table.Table, cols []*column.Col, ctx conte for i, r := range bufRecords { variable.GetSessionVars(ctx).SetLastInsertID(lastInsertIds[i]) - if _, err = t.AddRecord(ctx, r); err != nil { return nil, errors.Trace(err) } @@ -284,7 +283,7 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) if err == nil { continue } - if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrConditionNotMatch) { + if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrKeyExists) { return nil, errors.Trace(err) } // On duplicate key Update the duplicate row. diff --git a/table/tables/tables.go b/table/tables/tables.go index 07b645097b..ed737f1452 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -348,7 +348,7 @@ func (t *Table) AddRecord(ctx context.Context, r []interface{}) (recordID int64, } colVals, _ := v.FetchValues(r) if err = v.X.Create(txn, colVals, recordID); err != nil { - if errors2.ErrorEqual(err, kv.ErrConditionNotMatch) { + if errors2.ErrorEqual(err, kv.ErrKeyExists) { // Get the duplicate row handle iter, _, terr := v.X.Seek(txn, colVals) if terr != nil { diff --git a/tidb_test.go b/tidb_test.go index 746d9d50cb..a269c1af8b 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -28,6 +28,7 @@ import ( mysql "github.com/pingcap/tidb/mysqldef" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util/errors2" ) var store = flag.String("store", "memory", "registered store name, [memory, goleveldb, boltdb]") @@ -500,7 +501,6 @@ func (s *testSessionSuite) TestAutoicommit(c *C) { // See: http://dev.mysql.com/doc/refman/5.7/en/commit.html func (s *testSessionSuite) TestRowLock(c *C) { - c.Skip("Need retry feature") store := newStore(c, s.dbName) se := newSession(c, store, s.dbName) se1 := newSession(c, store, s.dbName) @@ -522,6 +522,11 @@ func (s *testSessionSuite) TestRowLock(c *C) { _, err := exec(c, se1, "commit") // row lock conflict but can still success + if errors2.ErrorNotEqual(err, kv.ErrConditionNotMatch) { + c.Fail() + } + // Retry should success + err = se.Retry() c.Assert(err, IsNil) mustExecSQL(c, se1, "begin") @@ -560,6 +565,9 @@ func (s *testSessionSuite) TestSelectForUpdate(c *C) { _, err = exec(c, se1, "commit") c.Assert(err, NotNil) + err = se1.Retry() + // retry should fail + c.Assert(err, NotNil) // not conflict mustExecSQL(c, se1, "begin")