stmt: adjust the code format
This commit is contained in:
@ -109,7 +109,7 @@ func (s *InsertIntoStmt) execSelect(t table.Table, cols []*column.Col, ctx conte
|
||||
marked[cols[i].Offset] = struct{}{}
|
||||
}
|
||||
|
||||
if err = s.initDefaultValues(ctx, t, t.Cols(), data0, marked); err != nil {
|
||||
if err = s.initDefaultValues(ctx, t, data0, marked); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
@ -185,15 +185,44 @@ func (s *InsertIntoStmt) getColumns(tableCols []*column.Col) ([]*column.Col, err
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
func (s *InsertIntoStmt) getDefaultValues(ctx context.Context, cols []*column.Col) (map[interface{}]interface{}, error) {
|
||||
m := map[interface{}]interface{}{}
|
||||
for _, v := range cols {
|
||||
if value, ok, err := getDefaultValue(ctx, v); ok {
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
m[v.Name.L] = value
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (s *InsertIntoStmt) getSetList() error {
|
||||
if len(s.Setlist) > 0 {
|
||||
if len(s.Lists) > 0 {
|
||||
return errors.Errorf("INSERT INTO %s: set type should not use values", s.TableIdent)
|
||||
}
|
||||
|
||||
var l []expression.Expression
|
||||
for _, v := range s.Setlist {
|
||||
l = append(l, v.Expr)
|
||||
}
|
||||
s.Lists = append(s.Lists, l)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exec implements the stmt.Statement Exec interface.
|
||||
func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
|
||||
t, err := getTable(ctx, s.TableIdent)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
tableCols := t.Cols()
|
||||
cols, err := s.getColumns(tableCols)
|
||||
cols, err := s.getColumns(t.Cols())
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
@ -204,118 +233,33 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error)
|
||||
}
|
||||
|
||||
// Process `insert ... set x=y...`
|
||||
if len(s.Setlist) > 0 {
|
||||
if len(s.Lists) > 0 {
|
||||
return nil, errors.Errorf("INSERT INTO %s: set type should not use values", s.TableIdent)
|
||||
}
|
||||
|
||||
var l []expression.Expression
|
||||
for _, v := range s.Setlist {
|
||||
l = append(l, v.Expr)
|
||||
}
|
||||
s.Lists = append(s.Lists, l)
|
||||
if err = s.getSetList(); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
m := map[interface{}]interface{}{}
|
||||
for _, v := range tableCols {
|
||||
var (
|
||||
value interface{}
|
||||
ok bool
|
||||
)
|
||||
value, ok, err = getDefaultValue(ctx, v)
|
||||
if ok {
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
m[v.Name.L] = value
|
||||
}
|
||||
m, err := s.getDefaultValues(ctx, t.Cols())
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
insertValueCount := len(s.Lists[0])
|
||||
toUpdateColumns, err0 := getOnDuplicateUpdateColumns(s.OnDuplicate, t)
|
||||
if err0 != nil {
|
||||
return nil, errors.Trace(err0)
|
||||
toUpdateColumns, err := getOnDuplicateUpdateColumns(s.OnDuplicate, t)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
toUpdateArgs := map[interface{}]interface{}{}
|
||||
for i, list := range s.Lists {
|
||||
r := make([]interface{}, len(tableCols))
|
||||
valueCount := len(list)
|
||||
|
||||
if insertValueCount != valueCount {
|
||||
// "insert into t values (), ()" is valid.
|
||||
// "insert into t values (), (1)" is not valid.
|
||||
// "insert into t values (1), ()" is not valid.
|
||||
// "insert into t values (1,2), (1)" is not valid.
|
||||
// So the value count must be same for all insert list.
|
||||
return nil, errors.Errorf("Column count doesn't match value count at row %d", i+1)
|
||||
}
|
||||
|
||||
if valueCount == 0 && len(s.ColNames) > 0 {
|
||||
// "insert into t (c1) values ()" is not valid.
|
||||
return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(s.ColNames), 0)
|
||||
} else if valueCount > 0 && valueCount != len(cols) {
|
||||
return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(cols), valueCount)
|
||||
}
|
||||
|
||||
// Clear last insert id.
|
||||
variable.GetSessionVars(ctx).SetLastInsertID(0)
|
||||
|
||||
marked := make(map[int]struct{}, len(list))
|
||||
for i, expr := range list {
|
||||
// For "insert into t values (default)" Default Eval.
|
||||
m[expression.ExprEvalDefaultName] = cols[i].Name.O
|
||||
|
||||
val, evalErr := expr.Eval(ctx, m)
|
||||
if evalErr != nil {
|
||||
return nil, errors.Trace(evalErr)
|
||||
}
|
||||
r[cols[i].Offset] = val
|
||||
marked[cols[i].Offset] = struct{}{}
|
||||
}
|
||||
|
||||
if err := s.initDefaultValues(ctx, t, tableCols, r, marked); err != nil {
|
||||
if err = s.checkValueCount(insertValueCount, len(list), i, cols); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
if err = column.CastValues(ctx, r, cols); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
if err = column.CheckNotNull(tableCols, r); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
// Notes: incompatible with mysql
|
||||
// MySQL will set last insert id to the first row, as follows:
|
||||
// `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))`
|
||||
// `insert t (c1) values(1),(2),(3);`
|
||||
// Last insert id will be 1, not 3.
|
||||
h, err := t.AddRecord(ctx, r)
|
||||
row, h, err := s.addRecord(ctx, t, cols, list, m)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrKeyExists) {
|
||||
|
||||
if h == -1 || len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrKeyExists) {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
// On duplicate key Update the duplicate row.
|
||||
// Evaluate the updated value.
|
||||
// TODO: report rows affected and last insert id.
|
||||
data, err := t.Row(ctx, h)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
toUpdateArgs[expression.ExprEvalValuesFunc] = func(name string) (interface{}, error) {
|
||||
c, err1 := findColumnByName(t, name)
|
||||
if err1 != nil {
|
||||
return nil, errors.Trace(err1)
|
||||
}
|
||||
return r[c.Offset], nil
|
||||
}
|
||||
|
||||
err = updateRecord(ctx, h, data, t, toUpdateColumns, toUpdateArgs, 0, true)
|
||||
if err != nil {
|
||||
if err = execOnDuplicateUpdate(ctx, t, row, h, toUpdateColumns); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
}
|
||||
@ -323,6 +267,89 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *InsertIntoStmt) checkValueCount(insertValueCount, valueCount, num int, cols []*column.Col) error {
|
||||
if insertValueCount != valueCount {
|
||||
// "insert into t values (), ()" is valid.
|
||||
// "insert into t values (), (1)" is not valid.
|
||||
// "insert into t values (1), ()" is not valid.
|
||||
// "insert into t values (1,2), (1)" is not valid.
|
||||
// So the value count must be same for all insert list.
|
||||
return errors.Errorf("Column count doesn't match value count at row %d", num+1)
|
||||
}
|
||||
if valueCount == 0 && len(s.ColNames) > 0 {
|
||||
// "insert into t (c1) values ()" is not valid.
|
||||
return errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(s.ColNames), 0)
|
||||
} else if valueCount > 0 && valueCount != len(cols) {
|
||||
return errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(cols), valueCount)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *InsertIntoStmt) addRecord(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, int64, error) {
|
||||
r := make([]interface{}, len(t.Cols()))
|
||||
marked := make(map[int]struct{}, len(list))
|
||||
for i, expr := range list {
|
||||
// For "insert into t values (default)" Default Eval.
|
||||
m[expression.ExprEvalDefaultName] = cols[i].Name.O
|
||||
|
||||
val, evalErr := expr.Eval(ctx, m)
|
||||
if evalErr != nil {
|
||||
return nil, -1, errors.Trace(evalErr)
|
||||
}
|
||||
r[cols[i].Offset] = val
|
||||
marked[cols[i].Offset] = struct{}{}
|
||||
}
|
||||
|
||||
// Clear last insert id.
|
||||
variable.GetSessionVars(ctx).SetLastInsertID(0)
|
||||
|
||||
err := s.initDefaultValues(ctx, t, r, marked)
|
||||
if err != nil {
|
||||
return nil, -1, errors.Trace(err)
|
||||
}
|
||||
if err = column.CastValues(ctx, r, cols); err != nil {
|
||||
return nil, -1, errors.Trace(err)
|
||||
}
|
||||
if err = column.CheckNotNull(t.Cols(), r); err != nil {
|
||||
return nil, -1, errors.Trace(err)
|
||||
}
|
||||
|
||||
// Notes: incompatible with mysql
|
||||
// MySQL will set last insert id to the first row, as follows:
|
||||
// `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))`
|
||||
// `insert t (c1) values(1),(2),(3);`
|
||||
// Last insert id will be 1, not 3.
|
||||
h, err := t.AddRecord(ctx, r)
|
||||
|
||||
return r, h, err
|
||||
}
|
||||
|
||||
func execOnDuplicateUpdate(ctx context.Context, t table.Table, row []interface{}, h int64, cols map[int]expression.Assignment) error {
|
||||
// On duplicate key update the duplicate row.
|
||||
// Evaluate the updated value.
|
||||
// TODO: report rows affected and last insert id.
|
||||
data, err := t.Row(ctx, h)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
toUpdateArgs := map[interface{}]interface{}{}
|
||||
toUpdateArgs[expression.ExprEvalValuesFunc] = func(name string) (interface{}, error) {
|
||||
c, err1 := findColumnByName(t, name)
|
||||
if err1 != nil {
|
||||
return nil, errors.Trace(err1)
|
||||
}
|
||||
return row[c.Offset], nil
|
||||
}
|
||||
|
||||
if err = updateRecord(ctx, h, data, t, cols, toUpdateArgs, 0, true); err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getOnDuplicateUpdateColumns(assignList []expression.Assignment, t table.Table) (map[int]expression.Assignment, error) {
|
||||
m := make(map[int]expression.Assignment, len(assignList))
|
||||
|
||||
@ -336,10 +363,10 @@ func getOnDuplicateUpdateColumns(assignList []expression.Assignment, t table.Tab
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (s *InsertIntoStmt) initDefaultValues(ctx context.Context, t table.Table, cols []*column.Col, row []interface{}, marked map[int]struct{}) error {
|
||||
func (s *InsertIntoStmt) initDefaultValues(ctx context.Context, t table.Table, row []interface{}, marked map[int]struct{}) error {
|
||||
var err error
|
||||
var defaultValueCols []*column.Col
|
||||
for i, c := range cols {
|
||||
for i, c := range t.Cols() {
|
||||
if row[i] != nil {
|
||||
// Column value is not nil, continue.
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user