diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index af936e0b04..0a10bad63d 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -251,12 +251,23 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) if err = s.checkValueCount(insertValueCount, len(list), i, cols); err != nil { return nil, errors.Trace(err) } - row, h, err := s.addRecord(ctx, t, cols, list, m) + + row, err := s.getRow(ctx, t, cols, list, m) + if err != nil { + return nil, errors.Trace(err) + } + + // Notes: incompatible with mysql + // MySQL will set last insert id to the first row, as follows: + // `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))` + // `insert t (c1) values(1),(2),(3);` + // Last insert id will be 1, not 3. + h, err := t.AddRecord(ctx, row) if err == nil { continue } - if h == -1 || len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrKeyExists) { + if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrKeyExists) { return nil, errors.Trace(err) } if err = execOnDuplicateUpdate(ctx, t, row, h, toUpdateColumns); err != nil { @@ -286,7 +297,7 @@ func (s *InsertIntoStmt) checkValueCount(insertValueCount, valueCount, num int, return nil } -func (s *InsertIntoStmt) addRecord(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, int64, error) { +func (s *InsertIntoStmt) getRow(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, error) { r := make([]interface{}, len(t.Cols())) marked := make(map[int]struct{}, len(list)) for i, expr := range list { @@ -295,7 +306,7 @@ func (s *InsertIntoStmt) addRecord(ctx context.Context, t table.Table, cols []*c val, evalErr := expr.Eval(ctx, m) if evalErr != nil { - return nil, -1, errors.Trace(evalErr) + return nil, errors.Trace(evalErr) } r[cols[i].Offset] = val marked[cols[i].Offset] = struct{}{} @@ -306,23 +317,16 @@ func (s *InsertIntoStmt) addRecord(ctx context.Context, t table.Table, cols []*c err := s.initDefaultValues(ctx, t, r, marked) if err != nil { - return nil, -1, errors.Trace(err) + return nil, errors.Trace(err) } if err = column.CastValues(ctx, r, cols); err != nil { - return nil, -1, errors.Trace(err) + return nil, errors.Trace(err) } if err = column.CheckNotNull(t.Cols(), r); err != nil { - return nil, -1, errors.Trace(err) + return nil, errors.Trace(err) } - // Notes: incompatible with mysql - // MySQL will set last insert id to the first row, as follows: - // `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))` - // `insert t (c1) values(1),(2),(3);` - // Last insert id will be 1, not 3. - h, err := t.AddRecord(ctx, r) - - return r, h, err + return r, err } func execOnDuplicateUpdate(ctx context.Context, t table.Table, row []interface{}, h int64, cols map[int]expression.Assignment) error {