diff --git a/executor/adapter.go b/executor/adapter.go index 4548f2058e..fbc8657b57 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -80,11 +80,17 @@ func (a *statement) Exec(ctx context.Context) (ast.RecordSet, error) { defer e.Close() for { row, err := e.Next() - if err != nil || row == nil { + if err != nil { return nil, errors.Trace(err) } + if row == nil { + // It's used to insert retry. + changeInsertValueForRetry(a.plan, e) + return nil, nil + } } } + fs := e.Fields() for _, f := range fs { if len(f.ColumnAsName.O) == 0 { @@ -96,3 +102,19 @@ func (a *statement) Exec(ctx context.Context) (ast.RecordSet, error) { fields: fs, }, nil } + +func changeInsertValueForRetry(p plan.Plan, e Executor) { + if v, ok := p.(*plan.Insert); ok { + var insertValue *InsertValues + if !v.IsReplace { + insertValue = e.(*InsertExec).InsertValues + } else { + insertValue = e.(*ReplaceExec).InsertValues + } + v.Columns = insertValue.Columns + v.Setlist = insertValue.Setlist + if len(v.Setlist) == 0 { + v.Lists = insertValue.Lists + } + } +} diff --git a/executor/executor_write.go b/executor/executor_write.go index c1eb73a947..784028117b 100644 --- a/executor/executor_write.go +++ b/executor/executor_write.go @@ -362,6 +362,7 @@ func (e *DeleteExec) Close() error { // InsertValues is the data to insert. type InsertValues struct { + currRow int ctx context.Context SelectExec Executor @@ -553,10 +554,12 @@ func (e *InsertValues) getRows(cols []*column.Col) (rows [][]types.Datum, err er } rows = make([][]types.Datum, len(e.Lists)) + length := len(e.Lists[0]) for i, list := range e.Lists { - if err = e.checkValueCount(len(e.Lists[0]), len(list), i, cols); err != nil { + if err = e.checkValueCount(length, len(list), i, cols); err != nil { return nil, errors.Trace(err) } + e.currRow = i rows[i], err = e.getRow(cols, list, defaultVals) if err != nil { return nil, errors.Trace(err) @@ -606,6 +609,7 @@ func (e *InsertValues) getRowsSelect(cols []*column.Col) ([][]types.Datum, error if innerRow == nil { break } + e.currRow = len(rows) row, err := e.fillRowData(cols, innerRow.Data) if err != nil { return nil, errors.Trace(err) @@ -637,6 +641,7 @@ func (e *InsertValues) fillRowData(cols []*column.Col, vals []types.Datum) ([]ty } func (e *InsertValues) initDefaultValues(row []types.Datum, marked map[int]struct{}) error { + var rewriteValueCol *column.Col var defaultValueCols []*column.Col for i, c := range e.Table.Cols() { if row[i].Kind() != types.KindNull { @@ -662,6 +667,8 @@ func (e *InsertValues) initDefaultValues(row []types.Datum, marked map[int]struc // `insert t (c1) values(1),(2),(3);` // Last insert id will be 1, not 3. variable.GetSessionVars(e.ctx).SetLastInsertID(uint64(recordID)) + // It's used for retry. + rewriteValueCol = c } } else { var err error @@ -676,6 +683,50 @@ func (e *InsertValues) initDefaultValues(row []types.Datum, marked map[int]struc if err := column.CastValues(e.ctx, row, defaultValueCols); err != nil { return errors.Trace(err) } + + // It's used for retry. + if rewriteValueCol == nil { + return nil + } + if len(e.Setlist) > 0 { + val := &ast.Assignment{ + Column: &ast.ColumnName{Name: rewriteValueCol.Name}, + Expr: ast.NewValueExpr(row[rewriteValueCol.Offset].GetValue())} + if len(e.Setlist) < rewriteValueCol.Offset+1 { + e.Setlist = append(e.Setlist, val) + return nil + } + setlist := make([]*ast.Assignment, 0, len(e.Setlist)+1) + setlist = append(setlist, e.Setlist[:rewriteValueCol.Offset]...) + setlist = append(setlist, val) + e.Setlist = append(setlist, e.Setlist[rewriteValueCol.Offset:]...) + return nil + } + + // records the values of each row. + vals := make([]ast.ExprNode, len(row)) + for i, col := range row { + vals[i] = ast.NewValueExpr(col.GetValue()) + } + if len(e.Lists) <= e.currRow { + e.Lists = append(e.Lists, vals) + } else { + e.Lists[e.currRow] = vals + } + + // records the column name only once. + if e.currRow != len(e.Lists)-1 { + return nil + } + if len(e.Columns) < rewriteValueCol.Offset+1 { + e.Columns = append(e.Columns, &ast.ColumnName{Name: rewriteValueCol.Name}) + return nil + } + cols := make([]*ast.ColumnName, 0, len(e.Columns)+1) + cols = append(cols, e.Columns[:rewriteValueCol.Offset]...) + cols = append(cols, &ast.ColumnName{Name: rewriteValueCol.Name}) + e.Columns = append(cols, e.Columns[rewriteValueCol.Offset:]...) + return nil } diff --git a/session_test.go b/session_test.go index fa7e20cc69..d3e7fd31e1 100644 --- a/session_test.go +++ b/session_test.go @@ -343,6 +343,106 @@ func (s *testSessionSuite) TestRowLock(c *C) { mustExecSQL(c, se, s.dropDBSQL) } +func (s *testSessionSuite) TestIssue827(c *C) { + store := newStore(c, s.dbName) + se := newSession(c, store, s.dbName) + se1 := newSession(c, store, s.dbName) + + mustExecSQL(c, se, "drop table if exists t1") + c.Assert(se.(*session).txn, IsNil) + mustExecSQL(c, se, "create table t1 (c2 int, c3 int, c1 int not null auto_increment, PRIMARY KEY (c1))") + mustExecSQL(c, se, "insert into t1 set c2 = 30") + + mustExecSQL(c, se, "drop table if exists t") + c.Assert(se.(*session).txn, IsNil) + mustExecSQL(c, se, "create table t (c2 int, c1 int not null auto_increment, PRIMARY KEY (c1))") + mustExecSQL(c, se, "insert into t (c2) values (1), (2), (3)") + + // insert values + lastInsertID := se.LastInsertID() + mustExecSQL(c, se, "begin") + mustExecSQL(c, se, "insert into t (c2) values (11), (12), (13)") + rs, err := exec(c, se, "select c1 from t where c2 = 11") + c.Assert(err, IsNil) + expect, err := GetRows(rs) + c.Assert(err, IsNil) + _, err = exec(c, se, "update t set c2 = 33 where c2 = 1") + c.Assert(err, IsNil) + + mustExecSQL(c, se1, "begin") + mustExecSQL(c, se1, "update t set c2 = 22 where c2 = 1") + mustExecSQL(c, se1, "commit") + + _, err = exec(c, se, "commit") + c.Assert(err, IsNil) + + rs, err = exec(c, se, "select c1 from t where c2 = 11") + c.Assert(err, IsNil) + r, err := GetRows(rs) + c.Assert(err, IsNil) + c.Assert(r, DeepEquals, expect) + currLastInsertID := se.LastInsertID() + c.Assert(lastInsertID+3, Equals, currLastInsertID) + + // insert set + lastInsertID = se.LastInsertID() + mustExecSQL(c, se, "begin") + mustExecSQL(c, se, "insert into t set c2 = 31") + rs, err = exec(c, se, "select c1 from t where c2 = 31") + c.Assert(err, IsNil) + expect, err = GetRows(rs) + c.Assert(err, IsNil) + _, err = exec(c, se, "update t set c2 = 44 where c2 = 2") + c.Assert(err, IsNil) + + mustExecSQL(c, se1, "begin") + mustExecSQL(c, se1, "update t set c2 = 55 where c2 = 2") + mustExecSQL(c, se1, "commit") + + _, err = exec(c, se, "commit") + c.Assert(err, IsNil) + + rs, err = exec(c, se, "select c1 from t where c2 = 31") + c.Assert(err, IsNil) + r, err = GetRows(rs) + c.Assert(err, IsNil) + c.Assert(r, DeepEquals, expect) + currLastInsertID = se.LastInsertID() + c.Assert(lastInsertID+1, Equals, currLastInsertID) + + // replace + lastInsertID = se.LastInsertID() + mustExecSQL(c, se, "begin") + mustExecSQL(c, se, "insert into t (c2) values (21), (22), (23)") + rs, err = exec(c, se, "select c1 from t where c2 = 21") + c.Assert(err, IsNil) + expect, err = GetRows(rs) + c.Assert(err, IsNil) + _, err = exec(c, se, "update t set c2 = 66 where c2 = 3") + c.Assert(err, IsNil) + + mustExecSQL(c, se1, "begin") + mustExecSQL(c, se1, "update t set c2 = 77 where c2 = 3") + mustExecSQL(c, se1, "commit") + + _, err = exec(c, se, "commit") + c.Assert(err, IsNil) + + rs, err = exec(c, se, "select c1 from t where c2 = 21") + c.Assert(err, IsNil) + r, err = GetRows(rs) + c.Assert(err, IsNil) + c.Assert(r, DeepEquals, expect) + currLastInsertID = se.LastInsertID() + c.Assert(lastInsertID+3, Equals, currLastInsertID) + + mustExecSQL(c, se, s.dropDBSQL) + err = se.Close() + c.Assert(err, IsNil) + err = se1.Close() + c.Assert(err, IsNil) +} + func (s *testSessionSuite) TestSelectForUpdate(c *C) { store := newStore(c, s.dbName) se := newSession(c, store, s.dbName)