diff --git a/column/column.go b/column/column.go index b05cc4f7c6..6158584071 100644 --- a/column/column.go +++ b/column/column.go @@ -202,6 +202,11 @@ func (c *Col) CheckNotNull(data interface{}) error { return nil } +// IsPKHandleColumn checks if the column is primary key handle column. +func (c *Col) IsPKHandleColumn(tbInfo *model.TableInfo) bool { + return mysql.HasPriKeyFlag(c.Flag) && tbInfo.PKIsHandle +} + // CheckNotNull checks if row has nil value set to a column with NotNull flag set. func CheckNotNull(cols []*Col, row []interface{}) error { for _, c := range cols { diff --git a/ddl/column_test.go b/ddl/column_test.go index 9bfc35736a..e9169c50dd 100644 --- a/ddl/column_test.go +++ b/ddl/column_test.go @@ -113,7 +113,7 @@ func (s *testColumnSuite) TestColumn(c *C) { num := 10 for i := 0; i < num; i++ { - _, err := t.AddRecord(ctx, []interface{}{i, 10 * i, 100 * i}, 0) + _, err := t.AddRecord(ctx, []interface{}{i, 10 * i, 100 * i}) c.Assert(err, IsNil) } @@ -155,7 +155,7 @@ func (s *testColumnSuite) TestColumn(c *C) { }) c.Assert(i, Equals, int64(num)) - h, err := t.AddRecord(ctx, []interface{}{11, 12, 13, 14}, 0) + h, err := t.AddRecord(ctx, []interface{}{11, 12, 13, 14}) c.Assert(err, IsNil) err = ctx.FinishTxn(false) c.Assert(err, IsNil) @@ -320,7 +320,7 @@ func (s *testColumnSuite) checkDeleteOnlyColumn(c *C, ctx context.Context, d *dd c.Assert(err, IsNil) newRow := []interface{}{int64(11), int64(22), int64(33)} - handle, err = t.AddRecord(ctx, newRow, 0) + handle, err = t.AddRecord(ctx, newRow) c.Assert(err, IsNil) txn, err = ctx.GetTxn(true) @@ -381,7 +381,7 @@ func (s *testColumnSuite) checkWriteOnlyColumn(c *C, ctx context.Context, d *ddl c.Assert(err, IsNil) newRow := []interface{}{int64(11), int64(22), int64(33)} - handle, err = t.AddRecord(ctx, newRow, 0) + handle, err = t.AddRecord(ctx, newRow) c.Assert(err, IsNil) txn, err = ctx.GetTxn(true) @@ -440,7 +440,7 @@ func (s *testColumnSuite) checkReorganizationColumn(c *C, ctx context.Context, d c.Assert(err, IsNil) newRow := []interface{}{int64(11), int64(22), int64(33)} - handle, err = t.AddRecord(ctx, newRow, 0) + handle, err = t.AddRecord(ctx, newRow) c.Assert(err, IsNil) txn, err = ctx.GetTxn(true) @@ -499,7 +499,7 @@ func (s *testColumnSuite) checkPublicColumn(c *C, ctx context.Context, d *ddl, t c.Assert(err, IsNil) newRow := []interface{}{int64(11), int64(22), int64(33), int64(44)} - handle, err = t.AddRecord(ctx, newRow, 0) + handle, err = t.AddRecord(ctx, newRow) c.Assert(err, IsNil) txn, err = ctx.GetTxn(true) @@ -577,7 +577,7 @@ func (s *testColumnSuite) TestAddColumn(c *C) { t := testGetTable(c, d, s.dbInfo.ID, tblInfo.ID) row := []interface{}{int64(1), int64(2), int64(3)} - handle, err := t.AddRecord(ctx, row, 0) + handle, err := t.AddRecord(ctx, row) c.Assert(err, IsNil) err = ctx.FinishTxn(false) @@ -645,7 +645,7 @@ func (s *testColumnSuite) TestDropColumn(c *C) { colName := "c4" defaultColValue := int64(4) row := []interface{}{int64(1), int64(2), int64(3)} - handle, err := t.AddRecord(ctx, append(row, defaultColValue), 0) + handle, err := t.AddRecord(ctx, append(row, defaultColValue)) c.Assert(err, IsNil) err = ctx.FinishTxn(false) diff --git a/ddl/ddl.go b/ddl/ddl.go index dfe77342e2..4f0523f382 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -424,6 +424,20 @@ func (d *ddl) buildTableInfo(tableName model.CIStr, cols []*column.Col, constrai tbInfo.Columns = append(tbInfo.Columns, &v.ColumnInfo) } for _, constr := range constraints { + if constr.Tp == coldef.ConstrPrimaryKey { + if len(constr.Keys) == 1 { + key := constr.Keys[0] + col := column.FindCol(cols, key.ColumnName) + if col == nil { + return nil, errors.Errorf("No such column: %v", key) + } + switch col.Tp { + case mysql.TypeLong, mysql.TypeLonglong: + tbInfo.PKIsHandle = true + } + } + } + // 1. check if the column is exists // 2. add index indexColumns := make([]*model.IndexColumn, 0, len(constr.Keys)) diff --git a/ddl/ddl_test.go b/ddl/ddl_test.go index ea2f9d677f..4a64ec615b 100644 --- a/ddl/ddl_test.go +++ b/ddl/ddl_test.go @@ -85,7 +85,7 @@ func (ts *testSuite) TestDDL(c *C) { tb, err := sessionctx.GetDomain(ctx).InfoSchema().TableByName(tbIdent.Schema, tbIdent.Name) c.Assert(err, IsNil) c.Assert(tb, NotNil) - _, err = tb.AddRecord(ctx, []interface{}{1, "b", 2, 4}, 0) + _, err = tb.AddRecord(ctx, []interface{}{1, "b", 2, 4}) c.Assert(err, IsNil) alterStmt := statement(ctx, "alter table t add column aa int first").(*stmts.AlterTableStmt) @@ -126,9 +126,9 @@ func (ts *testSuite) TestDDL(c *C) { tb, err = sessionctx.GetDomain(ctx).InfoSchema().TableByName(tbIdent2.Schema, tbIdent2.Name) c.Assert(err, IsNil) c.Assert(tb, NotNil) - rid0, err := tb.AddRecord(ctx, []interface{}{1}, 0) + rid0, err := tb.AddRecord(ctx, []interface{}{1}) c.Assert(err, IsNil) - rid1, err := tb.AddRecord(ctx, []interface{}{2}, 0) + rid1, err := tb.AddRecord(ctx, []interface{}{2}) c.Assert(err, IsNil) alterStmt = statement(ctx, `alter table t2 add b enum("bb") first`).(*stmts.AlterTableStmt) @@ -155,7 +155,7 @@ func (ts *testSuite) TestDDL(c *C) { c.Assert(cols[0], Equals, nil) c.Assert(cols[1], BytesEquals, []byte("abc")) c.Assert(cols[2], Equals, int64(2)) - rid3, err := tb.AddRecord(ctx, []interface{}{mysql.Enum{Name: "bb", Value: 1}, "c", 3}, 0) + rid3, err := tb.AddRecord(ctx, []interface{}{mysql.Enum{Name: "bb", Value: 1}, "c", 3}) c.Assert(err, IsNil) cols, err = tb.Row(ctx, rid3) c.Assert(err, IsNil) diff --git a/ddl/index_test.go b/ddl/index_test.go index c7261880d7..ded6807a80 100644 --- a/ddl/index_test.go +++ b/ddl/index_test.go @@ -99,7 +99,7 @@ func (s *testIndexSuite) TestIndex(c *C) { num := 10 for i := 0; i < num; i++ { - _, err = t.AddRecord(ctx, []interface{}{i, i, i}, 0) + _, err = t.AddRecord(ctx, []interface{}{i, i, i}) c.Assert(err, IsNil) } @@ -122,14 +122,14 @@ func (s *testIndexSuite) TestIndex(c *C) { index := t.FindIndexByColName("c1") c.Assert(index, NotNil) - h, err := t.AddRecord(ctx, []interface{}{num + 1, 1, 1}, 0) + h, err := t.AddRecord(ctx, []interface{}{num + 1, 1, 1}) c.Assert(err, IsNil) - h1, err := t.AddRecord(ctx, []interface{}{num + 1, 1, 1}, 0) + h1, err := t.AddRecord(ctx, []interface{}{num + 1, 1, 1}) c.Assert(err, NotNil) c.Assert(h, Equals, h1) - h, err = t.AddRecord(ctx, []interface{}{1, 1, 1}, 0) + h, err = t.AddRecord(ctx, []interface{}{1, 1, 1}) c.Assert(err, NotNil) txn, err = ctx.GetTxn(true) @@ -153,7 +153,7 @@ func (s *testIndexSuite) TestIndex(c *C) { c.Assert(err, IsNil) c.Assert(exist, IsFalse) - h, err = t.AddRecord(ctx, []interface{}{1, 1, 1}, 0) + h, err = t.AddRecord(ctx, []interface{}{1, 1, 1}) c.Assert(err, IsNil) } @@ -230,7 +230,7 @@ func (s *testIndexSuite) checkDeleteOnlyIndex(c *C, ctx context.Context, d *ddl, c.Assert(err, IsNil) newRow := []interface{}{int64(11), int64(22), int64(33)} - handle, err = t.AddRecord(ctx, newRow, 0) + handle, err = t.AddRecord(ctx, newRow) c.Assert(err, IsNil) txn, err = ctx.GetTxn(true) @@ -317,7 +317,7 @@ func (s *testIndexSuite) checkWriteOnlyIndex(c *C, ctx context.Context, d *ddl, c.Assert(err, IsNil) newRow := []interface{}{int64(11), int64(22), int64(33)} - handle, err = t.AddRecord(ctx, newRow, 0) + handle, err = t.AddRecord(ctx, newRow) c.Assert(err, IsNil) txn, err = ctx.GetTxn(true) @@ -397,7 +397,7 @@ func (s *testIndexSuite) checkReorganizationIndex(c *C, ctx context.Context, d * c.Assert(err, IsNil) newRow := []interface{}{int64(11), int64(22), int64(33)} - handle, err = t.AddRecord(ctx, newRow, 0) + handle, err = t.AddRecord(ctx, newRow) c.Assert(err, IsNil) txn, err = ctx.GetTxn(true) @@ -484,7 +484,7 @@ func (s *testIndexSuite) checkPublicIndex(c *C, ctx context.Context, d *ddl, tbl c.Assert(err, IsNil) newRow := []interface{}{int64(11), int64(22), int64(33)} - handle, err = t.AddRecord(ctx, newRow, 0) + handle, err = t.AddRecord(ctx, newRow) c.Assert(err, IsNil) txn, err = ctx.GetTxn(true) @@ -574,7 +574,7 @@ func (s *testIndexSuite) TestAddIndex(c *C) { t := testGetTable(c, d, s.dbInfo.ID, tblInfo.ID) row := []interface{}{int64(1), int64(2), int64(3)} - handle, err := t.AddRecord(ctx, row, 0) + handle, err := t.AddRecord(ctx, row) c.Assert(err, IsNil) err = ctx.FinishTxn(false) @@ -638,7 +638,7 @@ func (s *testIndexSuite) TestDropIndex(c *C) { t := testGetTable(c, d, s.dbInfo.ID, tblInfo.ID) row := []interface{}{int64(1), int64(2), int64(3)} - handle, err := t.AddRecord(ctx, row, 0) + handle, err := t.AddRecord(ctx, row) c.Assert(err, IsNil) err = ctx.FinishTxn(false) diff --git a/ddl/reorg_test.go b/ddl/reorg_test.go index 085f5c5c3c..b4fd027d97 100644 --- a/ddl/reorg_test.go +++ b/ddl/reorg_test.go @@ -136,7 +136,7 @@ func (s *testDDLSuite) TestReorgOwner(c *C) { num := 10 for i := 0; i < num; i++ { - _, err := t.AddRecord(ctx, []interface{}{i, i, i}, 0) + _, err := t.AddRecord(ctx, []interface{}{i, i, i}) c.Assert(err, IsNil) } diff --git a/ddl/table_test.go b/ddl/table_test.go index 01baba1f07..ef1cb6a114 100644 --- a/ddl/table_test.go +++ b/ddl/table_test.go @@ -174,10 +174,10 @@ func (s *testTableSuite) TestTable(c *C) { tbl := testGetTable(c, d, s.dbInfo.ID, tblInfo.ID) - _, err = tbl.AddRecord(ctx, []interface{}{1, 1, 1}, 0) + _, err = tbl.AddRecord(ctx, []interface{}{1, 1, 1}) c.Assert(err, IsNil) - _, err = tbl.AddRecord(ctx, []interface{}{2, 2, 2}, 0) + _, err = tbl.AddRecord(ctx, []interface{}{2, 2, 2}) c.Assert(err, IsNil) job = testDropTable(c, ctx, d, s.dbInfo, tblInfo) diff --git a/inspectkv/inspectkv_test.go b/inspectkv/inspectkv_test.go index f6f57f3268..28778062dc 100644 --- a/inspectkv/inspectkv_test.go +++ b/inspectkv/inspectkv_test.go @@ -157,7 +157,7 @@ func (s *testSuite) TestScan(c *C) { tb, err := tables.TableFromMeta(alloc, s.tbInfo) c.Assert(err, IsNil) indices := tb.Indices() - _, err = tb.AddRecord(s.ctx, []interface{}{10, 11}, 0) + _, err = tb.AddRecord(s.ctx, []interface{}{10, 11}) c.Assert(err, IsNil) s.ctx.FinishTxn(false) @@ -169,7 +169,7 @@ func (s *testSuite) TestScan(c *C) { c.Assert(err, IsNil) c.Assert(records, DeepEquals, []*RecordData{record1}) - _, err = tb.AddRecord(s.ctx, record2.Values, record2.Handle) + _, err = tb.AddRecord(s.ctx, record2.Values) c.Assert(err, IsNil) s.ctx.FinishTxn(false) txn, err := s.store.Begin() diff --git a/model/model.go b/model/model.go index 0fc7fecb8d..0159b541a2 100644 --- a/model/model.go +++ b/model/model.go @@ -79,9 +79,10 @@ type TableInfo struct { Charset string `json:"charset"` Collate string `json:"collate"` // Columns are listed in the order in which they appear in the schema. - Columns []*ColumnInfo `json:"cols"` - Indices []*IndexInfo `json:"index_info"` - State SchemaState `json:"state"` + Columns []*ColumnInfo `json:"cols"` + Indices []*IndexInfo `json:"index_info"` + State SchemaState `json:"state"` + PKIsHandle bool `json:"pk_is_handle"` } // Clone clones TableInfo. diff --git a/plan/plans/from_test.go b/plan/plans/from_test.go index 22d818c26c..44331da7b8 100644 --- a/plan/plans/from_test.go +++ b/plan/plans/from_test.go @@ -75,9 +75,12 @@ func (p *testFromSuit) SetUpSuite(c *C) { c.Assert(err, IsNil) p.vars = map[string]interface{}{} p.txn, _ = store.Begin() - p.cols = []*column.Col{ - { - ColumnInfo: model.ColumnInfo{ + tbInfo := &model.TableInfo{ + ID: 1, + Name: model.NewCIStr("t"), + State: model.StatePublic, + Columns: []*model.ColumnInfo{ + { ID: 0, Name: model.NewCIStr("id"), Offset: 0, @@ -85,9 +88,7 @@ func (p *testFromSuit) SetUpSuite(c *C) { FieldType: *types.NewFieldType(mysql.TypeLonglong), State: model.StatePublic, }, - }, - { - ColumnInfo: model.ColumnInfo{ + { ID: 1, Name: model.NewCIStr("name"), Offset: 1, @@ -97,14 +98,13 @@ func (p *testFromSuit) SetUpSuite(c *C) { }, }, } - - p.tbl = tables.NewTable(1, "t", p.cols, &simpleAllocator{}) - + p.tbl, err = tables.TableFromMeta(&simpleAllocator{}, tbInfo) + c.Assert(err, IsNil) variable.BindSessionVars(p) var i int64 for i = 0; i < 10; i++ { - _, err = p.tbl.AddRecord(p, []interface{}{i * 10, "hello"}, 0) + _, err = p.tbl.AddRecord(p, []interface{}{i * 10, "hello"}) c.Assert(err, IsNil) } } @@ -131,8 +131,8 @@ func (p *testFromSuit) TestTableDefaultPlan(c *C) { pln := &plans.TableDefaultPlan{ T: p.tbl, Fields: []*field.ResultField{ - field.ColToResultField(p.cols[0], "t"), - field.ColToResultField(p.cols[1], "t"), + field.ColToResultField(p.tbl.Cols()[0], "t"), + field.ColToResultField(p.tbl.Cols()[1], "t"), }, } diff --git a/plan/plans/index_test.go b/plan/plans/index_test.go index 1eda8fc304..f3b3b60e74 100644 --- a/plan/plans/index_test.go +++ b/plan/plans/index_test.go @@ -52,9 +52,12 @@ func (p *testIndexSuit) SetUpSuite(c *C) { p.store = store se, _ := tidb.CreateSession(store) p.ctx = se.(context.Context) - p.cols = []*column.Col{ - { - ColumnInfo: model.ColumnInfo{ + tbInfo := &model.TableInfo{ + ID: 2, + Name: model.NewCIStr("t2"), + State: model.StatePublic, + Columns: []*model.ColumnInfo{ + { ID: 0, Name: model.NewCIStr("id"), Offset: 0, @@ -62,9 +65,7 @@ func (p *testIndexSuit) SetUpSuite(c *C) { FieldType: *types.NewFieldType(mysql.TypeLonglong), State: model.StatePublic, }, - }, - { - ColumnInfo: model.ColumnInfo{ + { ID: 1, Name: model.NewCIStr("name"), Offset: 1, @@ -73,33 +74,29 @@ func (p *testIndexSuit) SetUpSuite(c *C) { State: model.StatePublic, }, }, - } - - p.tbl = tables.NewTable(2, "t2", p.cols, &simpleAllocator{}) - - idxCol := &column.IndexedCol{ - IndexInfo: model.IndexInfo{ - Name: model.NewCIStr("id"), - Table: model.NewCIStr("t2"), - Columns: []*model.IndexColumn{ - { - Name: model.NewCIStr("id"), - Offset: 0, - Length: 0, + Indices: []*model.IndexInfo{ + { + Name: model.NewCIStr("id"), + Table: model.NewCIStr("t2"), + Columns: []*model.IndexColumn{ + { + Name: model.NewCIStr("id"), + Offset: 0, + Length: 0, + }, }, + Unique: false, + Primary: false, + State: model.StatePublic, }, - Unique: false, - Primary: false, - State: model.StatePublic, }, } - idxCol.X = kv.NewKVIndex([]byte("i"), "id", 0, false) - - p.tbl.AddIndex(idxCol) + p.tbl, err = tables.TableFromMeta(&simpleAllocator{}, tbInfo) + c.Assert(err, IsNil) var i int64 for i = 0; i < 10; i++ { - p.tbl.AddRecord(p.ctx, []interface{}{i * 10, "hello"}, 0) + p.tbl.AddRecord(p.ctx, []interface{}{i * 10, "hello"}) } } @@ -123,8 +120,8 @@ func (p *testIndexSuit) TestIndexPlan(c *C) { pln := &plans.TableDefaultPlan{ T: p.tbl, Fields: []*field.ResultField{ - field.ColToResultField(p.cols[0], "t"), - field.ColToResultField(p.cols[1], "t"), + field.ColToResultField(p.tbl.Cols()[0], "t"), + field.ColToResultField(p.tbl.Cols()[1], "t"), }, } diff --git a/session_test.go b/session_test.go index 267f576ba3..fa890d09e6 100644 --- a/session_test.go +++ b/session_test.go @@ -1069,6 +1069,7 @@ func (s *testSessionSuite) TestIssue463(c *C) { // Testcase for https://github.com/pingcap/tidb/issues/463 store := newStore(c, s.dbName) se := newSession(c, store, s.dbName) + mustExecSQL(c, se, "DROP TABLE IF EXISTS test") mustExecSQL(c, se, `CREATE TABLE test ( id int(11) UNSIGNED NOT NULL AUTO_INCREMENT, @@ -1326,7 +1327,7 @@ func (s *testSessionSuite) TestErrorRollback(c *C) { // force generate a txn in session for later insert use. se.(*session).GetTxn(false) - se.Execute("insert into t_rollback values (1, 1, 1)") + se.Execute("insert into t_rollback values (1, 1)") _, err := se.Execute("update t_rollback set c2 = c2 + 1 where c1 = 0") c.Assert(err, IsNil) diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index a9d4a64842..5320f42124 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -182,29 +182,22 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) } var rows [][]interface{} - var recordIDs []int64 if s.Sel != nil { - rows, recordIDs, err = s.getRowsSelect(ctx, t, cols) + rows, err = s.getRowsSelect(ctx, t, cols) } else { - rows, recordIDs, err = s.getRows(ctx, t, cols) + rows, err = s.getRows(ctx, t, cols) } if err != nil { return nil, errors.Trace(err) } - for i, row := range rows { + for _, row := range rows { if len(s.OnDuplicate) == 0 { txn.SetOption(kv.PresumeKeyNotExists, nil) } - h, err := t.AddRecord(ctx, row, recordIDs[i]) + h, err := t.AddRecord(ctx, row) txn.DelOption(kv.PresumeKeyNotExists) if err == nil { - // Notes: incompatible with mysql - // MySQL will set last insert id to the first row, as follows: - // `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))` - // `insert t (c1) values(1),(2),(3);` - // Last insert id will be 1, not 3. - variable.GetSessionVars(ctx).SetLastInsertID(uint64(recordIDs[i])) continue } @@ -238,22 +231,21 @@ func (s *InsertValues) checkValueCount(insertValueCount, valueCount, num int, co return nil } -func (s *InsertValues) getRows(ctx context.Context, t table.Table, cols []*column.Col) (rows [][]interface{}, recordIDs []int64, err error) { +func (s *InsertValues) getRows(ctx context.Context, t table.Table, cols []*column.Col) (rows [][]interface{}, err error) { // process `insert|replace ... set x=y...` if err = s.fillValueList(); err != nil { - return nil, nil, errors.Trace(err) + return nil, errors.Trace(err) } evalMap, err := s.getColumnDefaultValues(ctx, t.Cols()) if err != nil { - return nil, nil, errors.Trace(err) + return nil, errors.Trace(err) } rows = make([][]interface{}, len(s.Lists)) - recordIDs = make([]int64, len(s.Lists)) for i, list := range s.Lists { if err = s.checkValueCount(len(s.Lists[0]), len(list), i, cols); err != nil { - return nil, nil, errors.Trace(err) + return nil, errors.Trace(err) } vals := make([]interface{}, len(list)) @@ -263,52 +255,50 @@ func (s *InsertValues) getRows(ctx context.Context, t table.Table, cols []*colum vals[j], err = expr.Eval(ctx, evalMap) if err != nil { - return nil, nil, errors.Trace(err) + return nil, errors.Trace(err) } } - rows[i], recordIDs[i], err = s.fillRowData(ctx, t, cols, vals) + rows[i], err = s.fillRowData(ctx, t, cols, vals) if err != nil { - return nil, nil, errors.Trace(err) + return nil, errors.Trace(err) } } return } -func (s *InsertValues) getRowsSelect(ctx context.Context, t table.Table, cols []*column.Col) (rows [][]interface{}, recordIDs []int64, err error) { +func (s *InsertValues) getRowsSelect(ctx context.Context, t table.Table, cols []*column.Col) (rows [][]interface{}, err error) { // process `insert|replace into ... select ... from ...` r, err := s.Sel.Plan(ctx) if err != nil { - return nil, nil, errors.Trace(err) + return nil, errors.Trace(err) } defer r.Close() if len(r.GetFields()) != len(cols) { - return nil, nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(r.GetFields())) + return nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(r.GetFields())) } for { var planRow *plan.Row planRow, err = r.Next(ctx) if err != nil { - return nil, nil, errors.Trace(err) + return nil, errors.Trace(err) } if planRow == nil { break } var row []interface{} - var recordID int64 - row, recordID, err = s.fillRowData(ctx, t, cols, planRow.Data) + row, err = s.fillRowData(ctx, t, cols, planRow.Data) if err != nil { - return nil, nil, errors.Trace(err) + return nil, errors.Trace(err) } rows = append(rows, row) - recordIDs = append(recordIDs, recordID) } return } -func (s *InsertValues) fillRowData(ctx context.Context, t table.Table, cols []*column.Col, vals []interface{}) ([]interface{}, int64, error) { +func (s *InsertValues) fillRowData(ctx context.Context, t table.Table, cols []*column.Col, vals []interface{}) ([]interface{}, error) { row := make([]interface{}, len(t.Cols())) marked := make(map[int]struct{}, len(vals)) for i, v := range vals { @@ -316,18 +306,17 @@ func (s *InsertValues) fillRowData(ctx context.Context, t table.Table, cols []*c row[offset] = v marked[offset] = struct{}{} } - recordID, err := s.initDefaultValues(ctx, t, row, marked) + err := s.initDefaultValues(ctx, t, row, marked) if err != nil { - return nil, 0, errors.Trace(err) + return nil, errors.Trace(err) } if err = column.CastValues(ctx, row, cols); err != nil { - return nil, 0, errors.Trace(err) + return nil, errors.Trace(err) } if err = column.CheckNotNull(t.Cols(), row); err != nil { - return nil, 0, errors.Trace(err) + return nil, errors.Trace(err) } - - return row, recordID, nil + return row, nil } func execOnDuplicateUpdate(ctx context.Context, t table.Table, row []interface{}, h int64, cols map[int]*expression.Assignment) error { @@ -368,7 +357,7 @@ func getOnDuplicateUpdateColumns(assignList []*expression.Assignment, t table.Ta return m, nil } -func (s *InsertValues) initDefaultValues(ctx context.Context, t table.Table, row []interface{}, marked map[int]struct{}) (recordID int64, err error) { +func (s *InsertValues) initDefaultValues(ctx context.Context, t table.Table, row []interface{}, marked map[int]struct{}) error { var defaultValueCols []*column.Col for i, c := range t.Cols() { if row[i] != nil { @@ -382,15 +371,24 @@ func (s *InsertValues) initDefaultValues(ctx context.Context, t table.Table, row } if mysql.HasAutoIncrementFlag(c.Flag) { - if recordID, err = t.AllocAutoID(); err != nil { - return 0, errors.Trace(err) + recordID, err := t.AllocAutoID() + if err != nil { + return errors.Trace(err) } row[i] = recordID + if c.IsPKHandleColumn(t.Meta()) { + // Notes: incompatible with mysql + // MySQL will set last insert id to the first row, as follows: + // `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))` + // `insert t (c1) values(1),(2),(3);` + // Last insert id will be 1, not 3. + variable.GetSessionVars(ctx).SetLastInsertID(uint64(recordID)) + } } else { var value interface{} - value, _, err = tables.GetColDefaultValue(ctx, &c.ColumnInfo) + value, _, err := tables.GetColDefaultValue(ctx, &c.ColumnInfo) if err != nil { - return 0, errors.Trace(err) + return errors.Trace(err) } row[i] = value @@ -399,9 +397,9 @@ func (s *InsertValues) initDefaultValues(ctx context.Context, t table.Table, row defaultValueCols = append(defaultValueCols, c) } - if err = column.CastValues(ctx, row, defaultValueCols); err != nil { - return 0, errors.Trace(err) + if err := column.CastValues(ctx, row, defaultValueCols); err != nil { + return errors.Trace(err) } - return + return nil } diff --git a/stmt/stmts/replace.go b/stmt/stmts/replace.go index a6018f6771..93933e1b81 100644 --- a/stmt/stmts/replace.go +++ b/stmt/stmts/replace.go @@ -68,18 +68,17 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error } var rows [][]interface{} - var recordIDs []int64 if s.Sel != nil { - rows, recordIDs, err = s.getRowsSelect(ctx, t, cols) + rows, err = s.getRowsSelect(ctx, t, cols) } else { - rows, recordIDs, err = s.getRows(ctx, t, cols) + rows, err = s.getRows(ctx, t, cols) } if err != nil { return nil, errors.Trace(err) } - for i, row := range rows { - h, err := t.AddRecord(ctx, row, recordIDs[i]) + for _, row := range rows { + h, err := t.AddRecord(ctx, row) if err == nil { continue } diff --git a/stmt/stmts/update.go b/stmt/stmts/update.go index 14c6887c3f..0976caac19 100644 --- a/stmt/stmts/update.go +++ b/stmt/stmts/update.go @@ -133,6 +133,7 @@ func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Tabl copy(newData, oldData) assignExists := false + var newHandle interface{} for i, asgn := range updateColumns { if i < offset || i >= offset+len(cols) { // The assign expression is for another table, not this. @@ -145,6 +146,11 @@ func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Tabl } colIndex := i - offset + col := cols[colIndex] + if col.IsPKHandleColumn(t.Meta()) { + newHandle = val + } + touched[colIndex] = true newData[colIndex] = val assignExists = true @@ -189,8 +195,17 @@ func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Tabl return nil } - // Update record to new value and update index. - err := t.UpdateRecord(ctx, h, oldData, newData, touched) + var err error + if newHandle != nil { + err = t.RemoveRecord(ctx, h, oldData) + if err != nil { + return errors.Trace(err) + } + _, err = t.AddRecord(ctx, newData) + } else { + // Update record to new value and update index. + err = t.UpdateRecord(ctx, h, oldData, newData, touched) + } if err != nil { return errors.Trace(err) } diff --git a/table/table.go b/table/table.go index 5f723478e8..9ce66b1a93 100644 --- a/table/table.go +++ b/table/table.go @@ -78,8 +78,8 @@ type Table interface { // Truncate truncates the table. Truncate(rm kv.RetrieverMutator) (err error) - // AddRecord inserts a row into the table. Is h is 0, it will alloc an unique id inside. - AddRecord(ctx context.Context, r []interface{}, h int64) (recordID int64, err error) + // AddRecord inserts a row into the table. + AddRecord(ctx context.Context, r []interface{}) (recordID int64, err error) // UpdateRecord updates a row in the table. UpdateRecord(ctx context.Context, h int64, currData []interface{}, newData []interface{}, touched map[int]bool) error diff --git a/table/tables/tables.go b/table/tables/tables.go index f2ad998125..eaaff24281 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -54,7 +54,7 @@ type Table struct { recordPrefix kv.Key indexPrefix kv.Key alloc autoid.Allocator - state model.SchemaState + meta *model.TableInfo } // TableFromMeta creates a Table instance from model.TableInfo. @@ -73,7 +73,7 @@ func TableFromMeta(alloc autoid.Allocator, tblInfo *model.TableInfo) (table.Tabl columns = append(columns, col) } - t := NewTable(tblInfo.ID, tblInfo.Name.O, columns, alloc) + t := newTable(tblInfo.ID, tblInfo.Name.O, columns, alloc) for _, idxInfo := range tblInfo.Indices { if idxInfo.State == model.StateNone { @@ -88,13 +88,12 @@ func TableFromMeta(alloc autoid.Allocator, tblInfo *model.TableInfo) (table.Tabl t.AddIndex(idx) } - - t.state = tblInfo.State + t.meta = tblInfo return t, nil } // NewTable constructs a Table instance. -func NewTable(tableID int64, tableName string, cols []*column.Col, alloc autoid.Allocator) *Table { +func newTable(tableID int64, tableName string, cols []*column.Col, alloc autoid.Allocator) *Table { name := model.NewCIStr(tableName) t := &Table{ ID: tableID, @@ -103,7 +102,6 @@ func NewTable(tableID int64, tableName string, cols []*column.Col, alloc autoid. indexPrefix: genTableIndexPrefix(tableID), alloc: alloc, Columns: cols, - state: model.StatePublic, } t.publicColumns = t.Cols() @@ -133,22 +131,7 @@ func (t *Table) TableName() model.CIStr { // Meta implements table.Table Meta interface. func (t *Table) Meta() *model.TableInfo { - ti := &model.TableInfo{ - Name: t.Name, - ID: t.ID, - State: t.state, - } - // load table meta - for _, col := range t.Columns { - ti.Columns = append(ti.Columns, &col.ColumnInfo) - } - - // load table indices - for _, idx := range t.indices { - ti.Indices = append(ti.Indices, &idx.IndexInfo) - } - - return ti + return t.meta } // Cols implements table.Table Cols interface. @@ -405,11 +388,19 @@ func (t *Table) rebuildIndices(rm kv.RetrieverMutator, h int64, touched map[int] } // AddRecord implements table.Table AddRecord interface. -func (t *Table) AddRecord(ctx context.Context, r []interface{}, h int64) (recordID int64, err error) { - // Already have recordID - if h != 0 { - recordID = int64(h) - } else { +func (t *Table) AddRecord(ctx context.Context, r []interface{}) (recordID int64, err error) { + var hasRecordID bool + for _, col := range t.Cols() { + if col.IsPKHandleColumn(t.meta) { + recordID, err = types.ToInt64(r[col.Offset]) + if err != nil { + return 0, errors.Trace(err) + } + hasRecordID = true + break + } + } + if !hasRecordID { recordID, err = t.alloc.Alloc(t.ID) if err != nil { return 0, errors.Trace(err) @@ -452,6 +443,10 @@ func (t *Table) AddRecord(ctx context.Context, r []interface{}, h int64) (record // Set public and write only column value. for _, col := range t.writableCols() { + if col.IsPKHandleColumn(t.meta) { + continue + } + var value interface{} if col.State == model.StateWriteOnly || col.State == model.StateWriteReorganization { // if col is in write only or write reorganization state, we must add it with its default value. @@ -509,6 +504,10 @@ func (t *Table) RowWithCols(retriever kv.Retriever, h int64, cols []*column.Col) if col.State != model.StatePublic { return nil, errors.Errorf("Cannot use none public column - %v", cols) } + if col.IsPKHandleColumn(t.meta) { + v[col.Offset] = h + continue + } k := t.RecordKey(h, col) data, err := retriever.Get(k) diff --git a/table/tables/tables_test.go b/table/tables/tables_test.go index e48e53df2f..81dbb392e2 100644 --- a/table/tables/tables_test.go +++ b/table/tables/tables_test.go @@ -70,7 +70,7 @@ func (ts *testSuite) TestBasic(c *C) { c.Assert(err, IsNil) c.Assert(autoid, Greater, int64(0)) - rid, err := tb.AddRecord(ctx, []interface{}{1, "abc"}, 0) + rid, err := tb.AddRecord(ctx, []interface{}{1, "abc"}) c.Assert(err, IsNil) c.Assert(rid, Greater, int64(0)) row, err := tb.Row(ctx, rid) @@ -78,9 +78,9 @@ func (ts *testSuite) TestBasic(c *C) { c.Assert(len(row), Equals, 2) c.Assert(row[0].(int64), Equals, int64(1)) - _, err = tb.AddRecord(ctx, []interface{}{1, "aba"}, 0) + _, err = tb.AddRecord(ctx, []interface{}{1, "aba"}) c.Assert(err, NotNil) - _, err = tb.AddRecord(ctx, []interface{}{2, "abc"}, 0) + _, err = tb.AddRecord(ctx, []interface{}{2, "abc"}) c.Assert(err, NotNil) c.Assert(tb.UpdateRecord(ctx, rid, []interface{}{1, "abc"}, []interface{}{1, "cba"}, map[int]bool{0: false, 1: true}), IsNil) @@ -102,7 +102,7 @@ func (ts *testSuite) TestBasic(c *C) { c.Assert(tb.RemoveRecord(ctx, rid, []interface{}{1, "cba"}), IsNil) // Make sure index data is also removed after tb.RemoveRecord(). c.Assert(indexCnt(), Equals, 0) - _, err = tb.AddRecord(ctx, []interface{}{1, "abc"}, 0) + _, err = tb.AddRecord(ctx, []interface{}{1, "abc"}) c.Assert(err, IsNil) c.Assert(indexCnt(), Greater, 0) // Make sure index data is also removed after tb.Truncate(). @@ -190,9 +190,9 @@ func (ts *testSuite) TestUniqueIndexMultipleNullEntries(c *C) { c.Assert(err, IsNil) c.Assert(autoid, Greater, int64(0)) - _, err = tb.AddRecord(ctx, []interface{}{1, nil}, 0) + _, err = tb.AddRecord(ctx, []interface{}{1, nil}) c.Assert(err, IsNil) - _, err = tb.AddRecord(ctx, []interface{}{2, nil}, 0) + _, err = tb.AddRecord(ctx, []interface{}{2, nil}) c.Assert(err, IsNil) _, err = ts.se.Execute("drop table test.t") c.Assert(err, IsNil)