stmt: refactor insert

This commit is contained in:
disksing
2015-11-23 15:57:23 +08:00
parent 38336cd64a
commit 868d6afd90
2 changed files with 94 additions and 108 deletions

View File

@ -85,66 +85,6 @@ func (s *InsertIntoStmt) SetText(text string) {
s.Text = text
}
// execExecSelect implements `insert table select ... from ...`.
func (s *InsertValues) execSelect(t table.Table, cols []*column.Col, ctx context.Context) (rset.Recordset, error) {
r, err := s.Sel.Plan(ctx)
if err != nil {
return nil, errors.Trace(err)
}
defer r.Close()
if len(r.GetFields()) != len(cols) {
return nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(r.GetFields()))
}
var bufRecords [][]interface{}
var recordIDs []int64
for {
var row *plan.Row
row, err = r.Next(ctx)
if err != nil {
return nil, errors.Trace(err)
}
if row == nil {
break
}
data0 := make([]interface{}, len(t.Cols()))
marked := make(map[int]struct{}, len(cols))
for i, d := range row.Data {
data0[cols[i].Offset] = d
marked[cols[i].Offset] = struct{}{}
}
var recordID int64
if recordID, err = s.initDefaultValues(ctx, t, data0, marked); err != nil {
return nil, errors.Trace(err)
}
if err = column.CastValues(ctx, data0, cols); err != nil {
return nil, errors.Trace(err)
}
if err = column.CheckNotNull(t.Cols(), data0); err != nil {
return nil, errors.Trace(err)
}
var v interface{}
v, err = types.Clone(data0)
if err != nil {
return nil, errors.Trace(err)
}
bufRecords = append(bufRecords, v.([]interface{}))
recordIDs = append(recordIDs, recordID)
}
for i, r := range bufRecords {
if _, err = t.AddRecord(ctx, r, recordIDs[i]); err != nil {
return nil, errors.Trace(err)
}
variable.GetSessionVars(ctx).SetLastInsertID(uint64(recordIDs[i]))
}
return nil, nil
}
// There are three types of insert statements:
// 1 insert ... values(...) --> name type column
// 2 insert ... set x=y... --> set type column
@ -232,39 +172,24 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error)
if err != nil {
return nil, errors.Trace(err)
}
// Process `insert ... (select ..) `
// TODO: handles the duplicate-key in a primary key or a unique index.
if s.Sel != nil {
return s.execSelect(t, cols, ctx)
}
// Process `insert ... set x=y...`
if err = s.fillValueList(); err != nil {
return nil, errors.Trace(err)
}
m, err := s.getDefaultValues(ctx, t.Cols())
txn, err := ctx.GetTxn(false)
if err != nil {
return nil, errors.Trace(err)
}
insertValueCount := len(s.Lists[0])
toUpdateColumns, err := getOnDuplicateUpdateColumns(s.OnDuplicate, t)
if err != nil {
return nil, errors.Trace(err)
}
rows := make([][]interface{}, len(s.Lists))
recordIDs := make([]int64, len(s.Lists))
for i, list := range s.Lists {
if err = s.checkValueCount(insertValueCount, len(list), i, cols); err != nil {
return nil, errors.Trace(err)
}
rows[i], recordIDs[i], err = s.getRow(ctx, t, cols, list, m)
if err != nil {
return nil, errors.Trace(err)
}
var rows [][]interface{}
var recordIDs []int64
if s.Sel != nil {
rows, recordIDs, err = s.getRowsSelect(ctx, t, cols)
} else {
rows, recordIDs, err = s.getRows(ctx, t, cols)
}
if err != nil {
return nil, errors.Trace(err)
}
if len(s.OnDuplicate) > 0 {
@ -274,11 +199,6 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error)
}
}
txn, err := ctx.GetTxn(false)
if err != nil {
return nil, errors.Trace(err)
}
for i, row := range rows {
// Notes: incompatible with mysql
// MySQL will set last insert id to the first row, as follows:
@ -364,6 +284,83 @@ func (s *InsertValues) checkValueCount(insertValueCount, valueCount, num int, co
return nil
}
func (s *InsertValues) getRows(ctx context.Context, t table.Table, cols []*column.Col) (rows [][]interface{}, recordIDs []int64, err error) {
// process `insert|replace ... set x=y...`
if err = s.fillValueList(); err != nil {
return nil, nil, errors.Trace(err)
}
m, err := s.getDefaultValues(ctx, t.Cols())
if err != nil {
return nil, nil, errors.Trace(err)
}
rows = make([][]interface{}, len(s.Lists))
recordIDs = make([]int64, len(s.Lists))
for i, list := range s.Lists {
if err = s.checkValueCount(len(s.Lists[0]), len(list), i, cols); err != nil {
return nil, nil, errors.Trace(err)
}
rows[i], recordIDs[i], err = s.getRow(ctx, t, cols, list, m)
if err != nil {
return nil, nil, errors.Trace(err)
}
}
return
}
func (s *InsertValues) getRowsSelect(ctx context.Context, t table.Table, cols []*column.Col) (rows [][]interface{}, recordIDs []int64, err error) {
// process `insert|replace into ... select ... from ...`
r, err := s.Sel.Plan(ctx)
if err != nil {
return nil, nil, errors.Trace(err)
}
defer r.Close()
if len(r.GetFields()) != len(cols) {
return nil, nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(r.GetFields()))
}
for {
var row *plan.Row
row, err = r.Next(ctx)
if err != nil {
return nil, nil, errors.Trace(err)
}
if row == nil {
break
}
data0 := make([]interface{}, len(t.Cols()))
marked := make(map[int]struct{}, len(cols))
for i, d := range row.Data {
data0[cols[i].Offset] = d
marked[cols[i].Offset] = struct{}{}
}
var recordID int64
if recordID, err = s.initDefaultValues(ctx, t, data0, marked); err != nil {
return nil, nil, errors.Trace(err)
}
if err = column.CastValues(ctx, data0, cols); err != nil {
return nil, nil, errors.Trace(err)
}
if err = column.CheckNotNull(t.Cols(), data0); err != nil {
return nil, nil, errors.Trace(err)
}
var v interface{}
v, err = types.Clone(data0)
if err != nil {
return nil, nil, errors.Trace(err)
}
rows = append(rows, v.([]interface{}))
recordIDs = append(recordIDs, recordID)
}
return
}
func (s *InsertValues) getRow(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, int64, error) {
r := make([]interface{}, len(t.Cols()))
marked := make(map[int]struct{}, len(list))

View File

@ -66,30 +66,19 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error
return nil, errors.Trace(err)
}
// Process `replace ... (select ...)`
// TODO: handles the duplicate-key in a primary key or a unique index.
var rows [][]interface{}
var recordIDs []int64
if s.Sel != nil {
return s.execSelect(t, cols, ctx)
rows, recordIDs, err = s.getRowsSelect(ctx, t, cols)
} else {
rows, recordIDs, err = s.getRows(ctx, t, cols)
}
// Process `replace ... set x=y ...`
if err = s.fillValueList(); err != nil {
return nil, errors.Trace(err)
}
m, err := s.getDefaultValues(ctx, t.Cols())
if err != nil {
return nil, errors.Trace(err)
}
replaceValueCount := len(s.Lists[0])
for i, list := range s.Lists {
if err = s.checkValueCount(replaceValueCount, len(list), i, cols); err != nil {
return nil, errors.Trace(err)
}
row, recordID, err := s.getRow(ctx, t, cols, list, m)
if err != nil {
return nil, errors.Trace(err)
}
h, err := t.AddRecord(ctx, row, recordID)
for i, row := range rows {
h, err := t.AddRecord(ctx, row, recordIDs[i])
if err == nil {
continue
}