stmt: refactor insert
This commit is contained in:
@ -85,66 +85,6 @@ func (s *InsertIntoStmt) SetText(text string) {
|
||||
s.Text = text
|
||||
}
|
||||
|
||||
// execExecSelect implements `insert table select ... from ...`.
|
||||
func (s *InsertValues) execSelect(t table.Table, cols []*column.Col, ctx context.Context) (rset.Recordset, error) {
|
||||
r, err := s.Sel.Plan(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
defer r.Close()
|
||||
if len(r.GetFields()) != len(cols) {
|
||||
return nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(r.GetFields()))
|
||||
}
|
||||
|
||||
var bufRecords [][]interface{}
|
||||
var recordIDs []int64
|
||||
for {
|
||||
var row *plan.Row
|
||||
row, err = r.Next(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
if row == nil {
|
||||
break
|
||||
}
|
||||
data0 := make([]interface{}, len(t.Cols()))
|
||||
marked := make(map[int]struct{}, len(cols))
|
||||
for i, d := range row.Data {
|
||||
data0[cols[i].Offset] = d
|
||||
marked[cols[i].Offset] = struct{}{}
|
||||
}
|
||||
|
||||
var recordID int64
|
||||
if recordID, err = s.initDefaultValues(ctx, t, data0, marked); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
if err = column.CastValues(ctx, data0, cols); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
if err = column.CheckNotNull(t.Cols(), data0); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
var v interface{}
|
||||
v, err = types.Clone(data0)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
bufRecords = append(bufRecords, v.([]interface{}))
|
||||
recordIDs = append(recordIDs, recordID)
|
||||
}
|
||||
|
||||
for i, r := range bufRecords {
|
||||
if _, err = t.AddRecord(ctx, r, recordIDs[i]); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
variable.GetSessionVars(ctx).SetLastInsertID(uint64(recordIDs[i]))
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// There are three types of insert statements:
|
||||
// 1 insert ... values(...) --> name type column
|
||||
// 2 insert ... set x=y... --> set type column
|
||||
@ -232,39 +172,24 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
// Process `insert ... (select ..) `
|
||||
// TODO: handles the duplicate-key in a primary key or a unique index.
|
||||
if s.Sel != nil {
|
||||
return s.execSelect(t, cols, ctx)
|
||||
}
|
||||
|
||||
// Process `insert ... set x=y...`
|
||||
if err = s.fillValueList(); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
m, err := s.getDefaultValues(ctx, t.Cols())
|
||||
txn, err := ctx.GetTxn(false)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
insertValueCount := len(s.Lists[0])
|
||||
toUpdateColumns, err := getOnDuplicateUpdateColumns(s.OnDuplicate, t)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
rows := make([][]interface{}, len(s.Lists))
|
||||
recordIDs := make([]int64, len(s.Lists))
|
||||
for i, list := range s.Lists {
|
||||
if err = s.checkValueCount(insertValueCount, len(list), i, cols); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
rows[i], recordIDs[i], err = s.getRow(ctx, t, cols, list, m)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
var rows [][]interface{}
|
||||
var recordIDs []int64
|
||||
if s.Sel != nil {
|
||||
rows, recordIDs, err = s.getRowsSelect(ctx, t, cols)
|
||||
} else {
|
||||
rows, recordIDs, err = s.getRows(ctx, t, cols)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
if len(s.OnDuplicate) > 0 {
|
||||
@ -274,11 +199,6 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error)
|
||||
}
|
||||
}
|
||||
|
||||
txn, err := ctx.GetTxn(false)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
for i, row := range rows {
|
||||
// Notes: incompatible with mysql
|
||||
// MySQL will set last insert id to the first row, as follows:
|
||||
@ -364,6 +284,83 @@ func (s *InsertValues) checkValueCount(insertValueCount, valueCount, num int, co
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *InsertValues) getRows(ctx context.Context, t table.Table, cols []*column.Col) (rows [][]interface{}, recordIDs []int64, err error) {
|
||||
// process `insert|replace ... set x=y...`
|
||||
if err = s.fillValueList(); err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
m, err := s.getDefaultValues(ctx, t.Cols())
|
||||
if err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
rows = make([][]interface{}, len(s.Lists))
|
||||
recordIDs = make([]int64, len(s.Lists))
|
||||
for i, list := range s.Lists {
|
||||
if err = s.checkValueCount(len(s.Lists[0]), len(list), i, cols); err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
rows[i], recordIDs[i], err = s.getRow(ctx, t, cols, list, m)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *InsertValues) getRowsSelect(ctx context.Context, t table.Table, cols []*column.Col) (rows [][]interface{}, recordIDs []int64, err error) {
|
||||
// process `insert|replace into ... select ... from ...`
|
||||
r, err := s.Sel.Plan(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
defer r.Close()
|
||||
if len(r.GetFields()) != len(cols) {
|
||||
return nil, nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(r.GetFields()))
|
||||
}
|
||||
|
||||
for {
|
||||
var row *plan.Row
|
||||
row, err = r.Next(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
if row == nil {
|
||||
break
|
||||
}
|
||||
data0 := make([]interface{}, len(t.Cols()))
|
||||
marked := make(map[int]struct{}, len(cols))
|
||||
for i, d := range row.Data {
|
||||
data0[cols[i].Offset] = d
|
||||
marked[cols[i].Offset] = struct{}{}
|
||||
}
|
||||
|
||||
var recordID int64
|
||||
if recordID, err = s.initDefaultValues(ctx, t, data0, marked); err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
if err = column.CastValues(ctx, data0, cols); err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
if err = column.CheckNotNull(t.Cols(), data0); err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
var v interface{}
|
||||
v, err = types.Clone(data0)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
rows = append(rows, v.([]interface{}))
|
||||
recordIDs = append(recordIDs, recordID)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *InsertValues) getRow(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, int64, error) {
|
||||
r := make([]interface{}, len(t.Cols()))
|
||||
marked := make(map[int]struct{}, len(list))
|
||||
|
||||
@ -66,30 +66,19 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
// Process `replace ... (select ...)`
|
||||
// TODO: handles the duplicate-key in a primary key or a unique index.
|
||||
var rows [][]interface{}
|
||||
var recordIDs []int64
|
||||
if s.Sel != nil {
|
||||
return s.execSelect(t, cols, ctx)
|
||||
rows, recordIDs, err = s.getRowsSelect(ctx, t, cols)
|
||||
} else {
|
||||
rows, recordIDs, err = s.getRows(ctx, t, cols)
|
||||
}
|
||||
// Process `replace ... set x=y ...`
|
||||
if err = s.fillValueList(); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
m, err := s.getDefaultValues(ctx, t.Cols())
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
replaceValueCount := len(s.Lists[0])
|
||||
|
||||
for i, list := range s.Lists {
|
||||
if err = s.checkValueCount(replaceValueCount, len(list), i, cols); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
row, recordID, err := s.getRow(ctx, t, cols, list, m)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
h, err := t.AddRecord(ctx, row, recordID)
|
||||
for i, row := range rows {
|
||||
h, err := t.AddRecord(ctx, row, recordIDs[i])
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user