Merge pull request #957 from pingcap/zimuxia/issue-827

*: Fix issue #827
This commit is contained in:
zimulala
2016-03-09 17:01:57 +08:00
3 changed files with 175 additions and 2 deletions

View File

@ -80,11 +80,17 @@ func (a *statement) Exec(ctx context.Context) (ast.RecordSet, error) {
defer e.Close()
for {
row, err := e.Next()
if err != nil || row == nil {
if err != nil {
return nil, errors.Trace(err)
}
if row == nil {
// It's used to insert retry.
changeInsertValueForRetry(a.plan, e)
return nil, nil
}
}
}
fs := e.Fields()
for _, f := range fs {
if len(f.ColumnAsName.O) == 0 {
@ -96,3 +102,19 @@ func (a *statement) Exec(ctx context.Context) (ast.RecordSet, error) {
fields: fs,
}, nil
}
func changeInsertValueForRetry(p plan.Plan, e Executor) {
if v, ok := p.(*plan.Insert); ok {
var insertValue *InsertValues
if !v.IsReplace {
insertValue = e.(*InsertExec).InsertValues
} else {
insertValue = e.(*ReplaceExec).InsertValues
}
v.Columns = insertValue.Columns
v.Setlist = insertValue.Setlist
if len(v.Setlist) == 0 {
v.Lists = insertValue.Lists
}
}
}

View File

@ -362,6 +362,7 @@ func (e *DeleteExec) Close() error {
// InsertValues is the data to insert.
type InsertValues struct {
currRow int
ctx context.Context
SelectExec Executor
@ -553,10 +554,12 @@ func (e *InsertValues) getRows(cols []*column.Col) (rows [][]types.Datum, err er
}
rows = make([][]types.Datum, len(e.Lists))
length := len(e.Lists[0])
for i, list := range e.Lists {
if err = e.checkValueCount(len(e.Lists[0]), len(list), i, cols); err != nil {
if err = e.checkValueCount(length, len(list), i, cols); err != nil {
return nil, errors.Trace(err)
}
e.currRow = i
rows[i], err = e.getRow(cols, list, defaultVals)
if err != nil {
return nil, errors.Trace(err)
@ -606,6 +609,7 @@ func (e *InsertValues) getRowsSelect(cols []*column.Col) ([][]types.Datum, error
if innerRow == nil {
break
}
e.currRow = len(rows)
row, err := e.fillRowData(cols, innerRow.Data)
if err != nil {
return nil, errors.Trace(err)
@ -637,6 +641,7 @@ func (e *InsertValues) fillRowData(cols []*column.Col, vals []types.Datum) ([]ty
}
func (e *InsertValues) initDefaultValues(row []types.Datum, marked map[int]struct{}) error {
var rewriteValueCol *column.Col
var defaultValueCols []*column.Col
for i, c := range e.Table.Cols() {
if row[i].Kind() != types.KindNull {
@ -662,6 +667,8 @@ func (e *InsertValues) initDefaultValues(row []types.Datum, marked map[int]struc
// `insert t (c1) values(1),(2),(3);`
// Last insert id will be 1, not 3.
variable.GetSessionVars(e.ctx).SetLastInsertID(uint64(recordID))
// It's used for retry.
rewriteValueCol = c
}
} else {
var err error
@ -676,6 +683,50 @@ func (e *InsertValues) initDefaultValues(row []types.Datum, marked map[int]struc
if err := column.CastValues(e.ctx, row, defaultValueCols); err != nil {
return errors.Trace(err)
}
// It's used for retry.
if rewriteValueCol == nil {
return nil
}
if len(e.Setlist) > 0 {
val := &ast.Assignment{
Column: &ast.ColumnName{Name: rewriteValueCol.Name},
Expr: ast.NewValueExpr(row[rewriteValueCol.Offset].GetValue())}
if len(e.Setlist) < rewriteValueCol.Offset+1 {
e.Setlist = append(e.Setlist, val)
return nil
}
setlist := make([]*ast.Assignment, 0, len(e.Setlist)+1)
setlist = append(setlist, e.Setlist[:rewriteValueCol.Offset]...)
setlist = append(setlist, val)
e.Setlist = append(setlist, e.Setlist[rewriteValueCol.Offset:]...)
return nil
}
// records the values of each row.
vals := make([]ast.ExprNode, len(row))
for i, col := range row {
vals[i] = ast.NewValueExpr(col.GetValue())
}
if len(e.Lists) <= e.currRow {
e.Lists = append(e.Lists, vals)
} else {
e.Lists[e.currRow] = vals
}
// records the column name only once.
if e.currRow != len(e.Lists)-1 {
return nil
}
if len(e.Columns) < rewriteValueCol.Offset+1 {
e.Columns = append(e.Columns, &ast.ColumnName{Name: rewriteValueCol.Name})
return nil
}
cols := make([]*ast.ColumnName, 0, len(e.Columns)+1)
cols = append(cols, e.Columns[:rewriteValueCol.Offset]...)
cols = append(cols, &ast.ColumnName{Name: rewriteValueCol.Name})
e.Columns = append(cols, e.Columns[rewriteValueCol.Offset:]...)
return nil
}

View File

@ -343,6 +343,106 @@ func (s *testSessionSuite) TestRowLock(c *C) {
mustExecSQL(c, se, s.dropDBSQL)
}
func (s *testSessionSuite) TestIssue827(c *C) {
store := newStore(c, s.dbName)
se := newSession(c, store, s.dbName)
se1 := newSession(c, store, s.dbName)
mustExecSQL(c, se, "drop table if exists t1")
c.Assert(se.(*session).txn, IsNil)
mustExecSQL(c, se, "create table t1 (c2 int, c3 int, c1 int not null auto_increment, PRIMARY KEY (c1))")
mustExecSQL(c, se, "insert into t1 set c2 = 30")
mustExecSQL(c, se, "drop table if exists t")
c.Assert(se.(*session).txn, IsNil)
mustExecSQL(c, se, "create table t (c2 int, c1 int not null auto_increment, PRIMARY KEY (c1))")
mustExecSQL(c, se, "insert into t (c2) values (1), (2), (3)")
// insert values
lastInsertID := se.LastInsertID()
mustExecSQL(c, se, "begin")
mustExecSQL(c, se, "insert into t (c2) values (11), (12), (13)")
rs, err := exec(c, se, "select c1 from t where c2 = 11")
c.Assert(err, IsNil)
expect, err := GetRows(rs)
c.Assert(err, IsNil)
_, err = exec(c, se, "update t set c2 = 33 where c2 = 1")
c.Assert(err, IsNil)
mustExecSQL(c, se1, "begin")
mustExecSQL(c, se1, "update t set c2 = 22 where c2 = 1")
mustExecSQL(c, se1, "commit")
_, err = exec(c, se, "commit")
c.Assert(err, IsNil)
rs, err = exec(c, se, "select c1 from t where c2 = 11")
c.Assert(err, IsNil)
r, err := GetRows(rs)
c.Assert(err, IsNil)
c.Assert(r, DeepEquals, expect)
currLastInsertID := se.LastInsertID()
c.Assert(lastInsertID+3, Equals, currLastInsertID)
// insert set
lastInsertID = se.LastInsertID()
mustExecSQL(c, se, "begin")
mustExecSQL(c, se, "insert into t set c2 = 31")
rs, err = exec(c, se, "select c1 from t where c2 = 31")
c.Assert(err, IsNil)
expect, err = GetRows(rs)
c.Assert(err, IsNil)
_, err = exec(c, se, "update t set c2 = 44 where c2 = 2")
c.Assert(err, IsNil)
mustExecSQL(c, se1, "begin")
mustExecSQL(c, se1, "update t set c2 = 55 where c2 = 2")
mustExecSQL(c, se1, "commit")
_, err = exec(c, se, "commit")
c.Assert(err, IsNil)
rs, err = exec(c, se, "select c1 from t where c2 = 31")
c.Assert(err, IsNil)
r, err = GetRows(rs)
c.Assert(err, IsNil)
c.Assert(r, DeepEquals, expect)
currLastInsertID = se.LastInsertID()
c.Assert(lastInsertID+1, Equals, currLastInsertID)
// replace
lastInsertID = se.LastInsertID()
mustExecSQL(c, se, "begin")
mustExecSQL(c, se, "insert into t (c2) values (21), (22), (23)")
rs, err = exec(c, se, "select c1 from t where c2 = 21")
c.Assert(err, IsNil)
expect, err = GetRows(rs)
c.Assert(err, IsNil)
_, err = exec(c, se, "update t set c2 = 66 where c2 = 3")
c.Assert(err, IsNil)
mustExecSQL(c, se1, "begin")
mustExecSQL(c, se1, "update t set c2 = 77 where c2 = 3")
mustExecSQL(c, se1, "commit")
_, err = exec(c, se, "commit")
c.Assert(err, IsNil)
rs, err = exec(c, se, "select c1 from t where c2 = 21")
c.Assert(err, IsNil)
r, err = GetRows(rs)
c.Assert(err, IsNil)
c.Assert(r, DeepEquals, expect)
currLastInsertID = se.LastInsertID()
c.Assert(lastInsertID+3, Equals, currLastInsertID)
mustExecSQL(c, se, s.dropDBSQL)
err = se.Close()
c.Assert(err, IsNil)
err = se1.Close()
c.Assert(err, IsNil)
}
func (s *testSessionSuite) TestSelectForUpdate(c *C) {
store := newStore(c, s.dbName)
se := newSession(c, store, s.dbName)