From b9f0bbe088296edfea90bcae8cd69daec22a4759 Mon Sep 17 00:00:00 2001 From: xia Date: Thu, 22 Oct 2015 17:59:12 +0800 Subject: [PATCH] stmt: adjust the code format --- stmt/stmts/insert.go | 235 ++++++++++++++++++++++++------------------- 1 file changed, 131 insertions(+), 104 deletions(-) diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index 61251f56f8..af936e0b04 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -109,7 +109,7 @@ func (s *InsertIntoStmt) execSelect(t table.Table, cols []*column.Col, ctx conte marked[cols[i].Offset] = struct{}{} } - if err = s.initDefaultValues(ctx, t, t.Cols(), data0, marked); err != nil { + if err = s.initDefaultValues(ctx, t, data0, marked); err != nil { return nil, errors.Trace(err) } @@ -185,15 +185,44 @@ func (s *InsertIntoStmt) getColumns(tableCols []*column.Col) ([]*column.Col, err return cols, nil } +func (s *InsertIntoStmt) getDefaultValues(ctx context.Context, cols []*column.Col) (map[interface{}]interface{}, error) { + m := map[interface{}]interface{}{} + for _, v := range cols { + if value, ok, err := getDefaultValue(ctx, v); ok { + if err != nil { + return nil, errors.Trace(err) + } + + m[v.Name.L] = value + } + } + + return m, nil +} + +func (s *InsertIntoStmt) getSetList() error { + if len(s.Setlist) > 0 { + if len(s.Lists) > 0 { + return errors.Errorf("INSERT INTO %s: set type should not use values", s.TableIdent) + } + + var l []expression.Expression + for _, v := range s.Setlist { + l = append(l, v.Expr) + } + s.Lists = append(s.Lists, l) + } + + return nil +} + // Exec implements the stmt.Statement Exec interface. func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { t, err := getTable(ctx, s.TableIdent) if err != nil { return nil, errors.Trace(err) } - - tableCols := t.Cols() - cols, err := s.getColumns(tableCols) + cols, err := s.getColumns(t.Cols()) if err != nil { return nil, errors.Trace(err) } @@ -204,118 +233,33 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) } // Process `insert ... set x=y...` - if len(s.Setlist) > 0 { - if len(s.Lists) > 0 { - return nil, errors.Errorf("INSERT INTO %s: set type should not use values", s.TableIdent) - } - - var l []expression.Expression - for _, v := range s.Setlist { - l = append(l, v.Expr) - } - s.Lists = append(s.Lists, l) + if err = s.getSetList(); err != nil { + return nil, errors.Trace(err) } - m := map[interface{}]interface{}{} - for _, v := range tableCols { - var ( - value interface{} - ok bool - ) - value, ok, err = getDefaultValue(ctx, v) - if ok { - if err != nil { - return nil, errors.Trace(err) - } - - m[v.Name.L] = value - } + m, err := s.getDefaultValues(ctx, t.Cols()) + if err != nil { + return nil, errors.Trace(err) } - insertValueCount := len(s.Lists[0]) - toUpdateColumns, err0 := getOnDuplicateUpdateColumns(s.OnDuplicate, t) - if err0 != nil { - return nil, errors.Trace(err0) + toUpdateColumns, err := getOnDuplicateUpdateColumns(s.OnDuplicate, t) + if err != nil { + return nil, errors.Trace(err) } - toUpdateArgs := map[interface{}]interface{}{} for i, list := range s.Lists { - r := make([]interface{}, len(tableCols)) - valueCount := len(list) - - if insertValueCount != valueCount { - // "insert into t values (), ()" is valid. - // "insert into t values (), (1)" is not valid. - // "insert into t values (1), ()" is not valid. - // "insert into t values (1,2), (1)" is not valid. - // So the value count must be same for all insert list. - return nil, errors.Errorf("Column count doesn't match value count at row %d", i+1) - } - - if valueCount == 0 && len(s.ColNames) > 0 { - // "insert into t (c1) values ()" is not valid. - return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(s.ColNames), 0) - } else if valueCount > 0 && valueCount != len(cols) { - return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(cols), valueCount) - } - - // Clear last insert id. - variable.GetSessionVars(ctx).SetLastInsertID(0) - - marked := make(map[int]struct{}, len(list)) - for i, expr := range list { - // For "insert into t values (default)" Default Eval. - m[expression.ExprEvalDefaultName] = cols[i].Name.O - - val, evalErr := expr.Eval(ctx, m) - if evalErr != nil { - return nil, errors.Trace(evalErr) - } - r[cols[i].Offset] = val - marked[cols[i].Offset] = struct{}{} - } - - if err := s.initDefaultValues(ctx, t, tableCols, r, marked); err != nil { + if err = s.checkValueCount(insertValueCount, len(list), i, cols); err != nil { return nil, errors.Trace(err) } - - if err = column.CastValues(ctx, r, cols); err != nil { - return nil, errors.Trace(err) - } - if err = column.CheckNotNull(tableCols, r); err != nil { - return nil, errors.Trace(err) - } - - // Notes: incompatible with mysql - // MySQL will set last insert id to the first row, as follows: - // `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))` - // `insert t (c1) values(1),(2),(3);` - // Last insert id will be 1, not 3. - h, err := t.AddRecord(ctx, r) + row, h, err := s.addRecord(ctx, t, cols, list, m) if err == nil { continue } - if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrKeyExists) { + + if h == -1 || len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrKeyExists) { return nil, errors.Trace(err) } - // On duplicate key Update the duplicate row. - // Evaluate the updated value. - // TODO: report rows affected and last insert id. - data, err := t.Row(ctx, h) - if err != nil { - return nil, errors.Trace(err) - } - - toUpdateArgs[expression.ExprEvalValuesFunc] = func(name string) (interface{}, error) { - c, err1 := findColumnByName(t, name) - if err1 != nil { - return nil, errors.Trace(err1) - } - return r[c.Offset], nil - } - - err = updateRecord(ctx, h, data, t, toUpdateColumns, toUpdateArgs, 0, true) - if err != nil { + if err = execOnDuplicateUpdate(ctx, t, row, h, toUpdateColumns); err != nil { return nil, errors.Trace(err) } } @@ -323,6 +267,89 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) return nil, nil } +func (s *InsertIntoStmt) checkValueCount(insertValueCount, valueCount, num int, cols []*column.Col) error { + if insertValueCount != valueCount { + // "insert into t values (), ()" is valid. + // "insert into t values (), (1)" is not valid. + // "insert into t values (1), ()" is not valid. + // "insert into t values (1,2), (1)" is not valid. + // So the value count must be same for all insert list. + return errors.Errorf("Column count doesn't match value count at row %d", num+1) + } + if valueCount == 0 && len(s.ColNames) > 0 { + // "insert into t (c1) values ()" is not valid. + return errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(s.ColNames), 0) + } else if valueCount > 0 && valueCount != len(cols) { + return errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(cols), valueCount) + } + + return nil +} + +func (s *InsertIntoStmt) addRecord(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, int64, error) { + r := make([]interface{}, len(t.Cols())) + marked := make(map[int]struct{}, len(list)) + for i, expr := range list { + // For "insert into t values (default)" Default Eval. + m[expression.ExprEvalDefaultName] = cols[i].Name.O + + val, evalErr := expr.Eval(ctx, m) + if evalErr != nil { + return nil, -1, errors.Trace(evalErr) + } + r[cols[i].Offset] = val + marked[cols[i].Offset] = struct{}{} + } + + // Clear last insert id. + variable.GetSessionVars(ctx).SetLastInsertID(0) + + err := s.initDefaultValues(ctx, t, r, marked) + if err != nil { + return nil, -1, errors.Trace(err) + } + if err = column.CastValues(ctx, r, cols); err != nil { + return nil, -1, errors.Trace(err) + } + if err = column.CheckNotNull(t.Cols(), r); err != nil { + return nil, -1, errors.Trace(err) + } + + // Notes: incompatible with mysql + // MySQL will set last insert id to the first row, as follows: + // `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))` + // `insert t (c1) values(1),(2),(3);` + // Last insert id will be 1, not 3. + h, err := t.AddRecord(ctx, r) + + return r, h, err +} + +func execOnDuplicateUpdate(ctx context.Context, t table.Table, row []interface{}, h int64, cols map[int]expression.Assignment) error { + // On duplicate key update the duplicate row. + // Evaluate the updated value. + // TODO: report rows affected and last insert id. + data, err := t.Row(ctx, h) + if err != nil { + return errors.Trace(err) + } + + toUpdateArgs := map[interface{}]interface{}{} + toUpdateArgs[expression.ExprEvalValuesFunc] = func(name string) (interface{}, error) { + c, err1 := findColumnByName(t, name) + if err1 != nil { + return nil, errors.Trace(err1) + } + return row[c.Offset], nil + } + + if err = updateRecord(ctx, h, data, t, cols, toUpdateArgs, 0, true); err != nil { + return errors.Trace(err) + } + + return nil +} + func getOnDuplicateUpdateColumns(assignList []expression.Assignment, t table.Table) (map[int]expression.Assignment, error) { m := make(map[int]expression.Assignment, len(assignList)) @@ -336,10 +363,10 @@ func getOnDuplicateUpdateColumns(assignList []expression.Assignment, t table.Tab return m, nil } -func (s *InsertIntoStmt) initDefaultValues(ctx context.Context, t table.Table, cols []*column.Col, row []interface{}, marked map[int]struct{}) error { +func (s *InsertIntoStmt) initDefaultValues(ctx context.Context, t table.Table, row []interface{}, marked map[int]struct{}) error { var err error var defaultValueCols []*column.Col - for i, c := range cols { + for i, c := range t.Cols() { if row[i] != nil { // Column value is not nil, continue. continue