diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index aa86d03eb4..134a565b79 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -85,66 +85,6 @@ func (s *InsertIntoStmt) SetText(text string) { s.Text = text } -// execExecSelect implements `insert table select ... from ...`. -func (s *InsertValues) execSelect(t table.Table, cols []*column.Col, ctx context.Context) (rset.Recordset, error) { - r, err := s.Sel.Plan(ctx) - if err != nil { - return nil, errors.Trace(err) - } - defer r.Close() - if len(r.GetFields()) != len(cols) { - return nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(r.GetFields())) - } - - var bufRecords [][]interface{} - var recordIDs []int64 - for { - var row *plan.Row - row, err = r.Next(ctx) - if err != nil { - return nil, errors.Trace(err) - } - if row == nil { - break - } - data0 := make([]interface{}, len(t.Cols())) - marked := make(map[int]struct{}, len(cols)) - for i, d := range row.Data { - data0[cols[i].Offset] = d - marked[cols[i].Offset] = struct{}{} - } - - var recordID int64 - if recordID, err = s.initDefaultValues(ctx, t, data0, marked); err != nil { - return nil, errors.Trace(err) - } - - if err = column.CastValues(ctx, data0, cols); err != nil { - return nil, errors.Trace(err) - } - - if err = column.CheckNotNull(t.Cols(), data0); err != nil { - return nil, errors.Trace(err) - } - var v interface{} - v, err = types.Clone(data0) - if err != nil { - return nil, errors.Trace(err) - } - - bufRecords = append(bufRecords, v.([]interface{})) - recordIDs = append(recordIDs, recordID) - } - - for i, r := range bufRecords { - if _, err = t.AddRecord(ctx, r, recordIDs[i]); err != nil { - return nil, errors.Trace(err) - } - variable.GetSessionVars(ctx).SetLastInsertID(uint64(recordIDs[i])) - } - return nil, nil -} - // There are three types of insert statements: // 1 insert ... values(...) --> name type column // 2 insert ... set x=y... --> set type column @@ -232,39 +172,24 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) if err != nil { return nil, errors.Trace(err) } - - // Process `insert ... (select ..) ` - // TODO: handles the duplicate-key in a primary key or a unique index. - if s.Sel != nil { - return s.execSelect(t, cols, ctx) - } - - // Process `insert ... set x=y...` - if err = s.fillValueList(); err != nil { - return nil, errors.Trace(err) - } - - m, err := s.getDefaultValues(ctx, t.Cols()) + txn, err := ctx.GetTxn(false) if err != nil { return nil, errors.Trace(err) } - insertValueCount := len(s.Lists[0]) toUpdateColumns, err := getOnDuplicateUpdateColumns(s.OnDuplicate, t) if err != nil { return nil, errors.Trace(err) } - rows := make([][]interface{}, len(s.Lists)) - recordIDs := make([]int64, len(s.Lists)) - for i, list := range s.Lists { - if err = s.checkValueCount(insertValueCount, len(list), i, cols); err != nil { - return nil, errors.Trace(err) - } - - rows[i], recordIDs[i], err = s.getRow(ctx, t, cols, list, m) - if err != nil { - return nil, errors.Trace(err) - } + var rows [][]interface{} + var recordIDs []int64 + if s.Sel != nil { + rows, recordIDs, err = s.getRowsSelect(ctx, t, cols) + } else { + rows, recordIDs, err = s.getRows(ctx, t, cols) + } + if err != nil { + return nil, errors.Trace(err) } if len(s.OnDuplicate) > 0 { @@ -274,11 +199,6 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) } } - txn, err := ctx.GetTxn(false) - if err != nil { - return nil, errors.Trace(err) - } - for i, row := range rows { // Notes: incompatible with mysql // MySQL will set last insert id to the first row, as follows: @@ -364,6 +284,83 @@ func (s *InsertValues) checkValueCount(insertValueCount, valueCount, num int, co return nil } +func (s *InsertValues) getRows(ctx context.Context, t table.Table, cols []*column.Col) (rows [][]interface{}, recordIDs []int64, err error) { + // process `insert|replace ... set x=y...` + if err = s.fillValueList(); err != nil { + return nil, nil, errors.Trace(err) + } + + m, err := s.getDefaultValues(ctx, t.Cols()) + if err != nil { + return nil, nil, errors.Trace(err) + } + + rows = make([][]interface{}, len(s.Lists)) + recordIDs = make([]int64, len(s.Lists)) + for i, list := range s.Lists { + if err = s.checkValueCount(len(s.Lists[0]), len(list), i, cols); err != nil { + return nil, nil, errors.Trace(err) + } + + rows[i], recordIDs[i], err = s.getRow(ctx, t, cols, list, m) + if err != nil { + return nil, nil, errors.Trace(err) + } + } + return +} + +func (s *InsertValues) getRowsSelect(ctx context.Context, t table.Table, cols []*column.Col) (rows [][]interface{}, recordIDs []int64, err error) { + // process `insert|replace into ... select ... from ...` + r, err := s.Sel.Plan(ctx) + if err != nil { + return nil, nil, errors.Trace(err) + } + defer r.Close() + if len(r.GetFields()) != len(cols) { + return nil, nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(r.GetFields())) + } + + for { + var row *plan.Row + row, err = r.Next(ctx) + if err != nil { + return nil, nil, errors.Trace(err) + } + if row == nil { + break + } + data0 := make([]interface{}, len(t.Cols())) + marked := make(map[int]struct{}, len(cols)) + for i, d := range row.Data { + data0[cols[i].Offset] = d + marked[cols[i].Offset] = struct{}{} + } + + var recordID int64 + if recordID, err = s.initDefaultValues(ctx, t, data0, marked); err != nil { + return nil, nil, errors.Trace(err) + } + + if err = column.CastValues(ctx, data0, cols); err != nil { + return nil, nil, errors.Trace(err) + } + + if err = column.CheckNotNull(t.Cols(), data0); err != nil { + return nil, nil, errors.Trace(err) + } + var v interface{} + v, err = types.Clone(data0) + if err != nil { + return nil, nil, errors.Trace(err) + } + + rows = append(rows, v.([]interface{})) + recordIDs = append(recordIDs, recordID) + } + return +} + func (s *InsertValues) getRow(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, int64, error) { r := make([]interface{}, len(t.Cols())) marked := make(map[int]struct{}, len(list)) diff --git a/stmt/stmts/replace.go b/stmt/stmts/replace.go index a4b2f0c680..50c3ebcdfe 100644 --- a/stmt/stmts/replace.go +++ b/stmt/stmts/replace.go @@ -66,30 +66,19 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error return nil, errors.Trace(err) } - // Process `replace ... (select ...)` - // TODO: handles the duplicate-key in a primary key or a unique index. + var rows [][]interface{} + var recordIDs []int64 if s.Sel != nil { - return s.execSelect(t, cols, ctx) + rows, recordIDs, err = s.getRowsSelect(ctx, t, cols) + } else { + rows, recordIDs, err = s.getRows(ctx, t, cols) } - // Process `replace ... set x=y ...` - if err = s.fillValueList(); err != nil { - return nil, errors.Trace(err) - } - m, err := s.getDefaultValues(ctx, t.Cols()) if err != nil { return nil, errors.Trace(err) } - replaceValueCount := len(s.Lists[0]) - for i, list := range s.Lists { - if err = s.checkValueCount(replaceValueCount, len(list), i, cols); err != nil { - return nil, errors.Trace(err) - } - row, recordID, err := s.getRow(ctx, t, cols, list, m) - if err != nil { - return nil, errors.Trace(err) - } - h, err := t.AddRecord(ctx, row, recordID) + for i, row := range rows { + h, err := t.AddRecord(ctx, row, recordIDs[i]) if err == nil { continue }