diff --git a/column/column.go b/column/column.go index d6831c8f97..75581e0ff1 100644 --- a/column/column.go +++ b/column/column.go @@ -55,10 +55,14 @@ func (c *Col) String() string { } // FindCol finds column in cols by name. -func FindCol(cols []*Col, name string) (c *Col) { - for _, c = range cols { - if strings.EqualFold(c.Name.O, name) { - return +func FindCol(cols []*Col, name string) *Col { + for _, col := range cols { + if col.State != model.StatePublic { + continue + } + + if strings.EqualFold(col.Name.O, name) { + return col } } return nil @@ -82,9 +86,13 @@ func FindCols(cols []*Col, names []string) ([]*Col, error) { // FindOnUpdateCols finds columns which have OnUpdateNow flag. func FindOnUpdateCols(cols []*Col) []*Col { var rcols []*Col - for _, c := range cols { - if mysql.HasOnUpdateNowFlag(c.Flag) { - rcols = append(rcols, c) + for _, col := range cols { + if col.State != model.StatePublic { + continue + } + + if mysql.HasOnUpdateNowFlag(col.Flag) { + rcols = append(rcols, col) } } @@ -181,8 +189,12 @@ func ColDescFieldNames(full bool) []string { // CheckOnce checks if there are duplicated column names in cols. func CheckOnce(cols []*Col) error { m := map[string]struct{}{} - for _, v := range cols { - name := v.Name + for _, col := range cols { + if col.State != model.StatePublic { + continue + } + + name := col.Name _, ok := m[name.L] if ok { return errors.Errorf("column specified twice - %s", name) diff --git a/column/column_test.go b/column/column_test.go index 45710b588b..618acf0ecc 100644 --- a/column/column_test.go +++ b/column/column_test.go @@ -34,6 +34,7 @@ func (s *testColumnSuite) TestString(c *C) { col := &Col{ model.ColumnInfo{ FieldType: *types.NewFieldType(mysql.TypeTiny), + State: model.StatePublic, }, } col.Flen = 2 @@ -109,7 +110,8 @@ func (s *testColumnSuite) TestDesc(c *C) { func newCol(name string) *Col { return &Col{ model.ColumnInfo{ - Name: model.NewCIStr(name), + Name: model.NewCIStr(name), + State: model.StatePublic, }, } } diff --git a/ddl/column.go b/ddl/column.go new file mode 100644 index 0000000000..9b3ad9b1b2 --- /dev/null +++ b/ddl/column.go @@ -0,0 +1,277 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "sync/atomic" + + "github.com/juju/errors" + "github.com/ngaut/log" + "github.com/pingcap/tidb/column" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/meta" + "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/table/tables" + "github.com/pingcap/tidb/util/errors2" +) + +func (d *ddl) adjustColumnOffset(columns []*model.ColumnInfo, indices []*model.IndexInfo, offset int) { + offsetChanged := make(map[int]int) + for i := offset; i < len(columns); i++ { + offsetChanged[columns[i].Offset] = i + columns[i].Offset = i + } + + // Update index column offset info. + for _, idx := range indices { + for _, col := range idx.Columns { + newOffset, ok := offsetChanged[col.Offset] + if ok { + col.Offset = newOffset + } + } + } +} + +func (d *ddl) addColumn(tblInfo *model.TableInfo, spec *AlterSpecification) (*model.ColumnInfo, int, error) { + // Check column name duplicate. + cols := tblInfo.Columns + position := len(cols) + + // Get column position. + if spec.Position.Type == ColumnPositionFirst { + position = 0 + } else if spec.Position.Type == ColumnPositionAfter { + c := findCol(cols, spec.Position.RelativeColumn) + if c == nil { + return nil, 0, errors.Errorf("No such column: %v", spec.Column.Name) + } + + // Insert position is after the mentioned column. + position = c.Offset + 1 + } + + // TODO: set constraint + col, _, err := d.buildColumnAndConstraint(position, spec.Column) + if err != nil { + return nil, 0, errors.Trace(err) + } + + colInfo := &col.ColumnInfo + colInfo.State = model.StateNone + // To support add column asynchronous, we should mark its offset as the last column. + // So that we can use origin column offset to get value from row. + colInfo.Offset = len(cols) + + // Insert col into the right place of the column list. + newCols := make([]*model.ColumnInfo, 0, len(cols)+1) + newCols = append(newCols, cols[:position]...) + newCols = append(newCols, colInfo) + newCols = append(newCols, cols[position:]...) + + tblInfo.Columns = newCols + return colInfo, position, nil +} + +func (d *ddl) onAddColumn(t *meta.Meta, job *model.Job) error { + schemaID := job.SchemaID + tblInfo, err := d.getTableInfo(t, job) + if err != nil { + return errors.Trace(err) + } + + spec := &AlterSpecification{} + offset := 0 + err = job.DecodeArgs(&spec, &offset) + if err != nil { + job.State = model.JobCancelled + return errors.Trace(err) + } + + columnInfo := findCol(tblInfo.Columns, spec.Column.Name) + if columnInfo != nil { + if columnInfo.State == model.StatePublic { + // we already have a column with same column name + job.State = model.JobCancelled + return errors.Errorf("ADD COLUMN: column already exist %s", spec.Column.Name) + } + } else { + columnInfo, offset, err = d.addColumn(tblInfo, spec) + if err != nil { + job.State = model.JobCancelled + return errors.Trace(err) + } + + // Set offset arg to job. + if offset != 0 { + job.Args = []interface{}{spec, offset} + } + } + + _, err = t.GenSchemaVersion() + if err != nil { + return errors.Trace(err) + } + + switch columnInfo.State { + case model.StateNone: + // none -> delete only + job.SchemaState = model.StateDeleteOnly + columnInfo.State = model.StateDeleteOnly + err = t.UpdateTable(schemaID, tblInfo) + return errors.Trace(err) + case model.StateDeleteOnly: + // delete only -> write only + job.SchemaState = model.StateWriteOnly + columnInfo.State = model.StateWriteOnly + err = t.UpdateTable(schemaID, tblInfo) + return errors.Trace(err) + case model.StateWriteOnly: + // write only -> reorganization + job.SchemaState = model.StateReorganization + columnInfo.State = model.StateReorganization + // initialize SnapshotVer to 0 for later reorganization check. + job.SnapshotVer = 0 + // initialize reorg handle to 0 + job.ReorgHandle = 0 + atomic.StoreInt64(&d.reorgHandle, 0) + err = t.UpdateTable(schemaID, tblInfo) + return errors.Trace(err) + case model.StateReorganization: + // reorganization -> public + // get the current version for reorganization if we don't have + if job.SnapshotVer == 0 { + var ver kv.Version + ver, err = d.store.CurrentVersion() + if err != nil { + return errors.Trace(err) + } + + job.SnapshotVer = ver.Ver + } + + tbl, err := d.getTable(t, schemaID, tblInfo) + if err != nil { + return errors.Trace(err) + } + + err = d.runReorgJob(func() error { + return d.backfillColumn(tbl, columnInfo, job.SnapshotVer, job.ReorgHandle) + }) + + // backfillColumn updates ReorgHandle after one batch. + // so we update the job ReorgHandle here. + job.ReorgHandle = atomic.LoadInt64(&d.reorgHandle) + + if errors2.ErrorEqual(err, errWaitReorgTimeout) { + // if timeout, we should return, check for the owner and re-wait job done. + return nil + } + if err != nil { + return errors.Trace(err) + } + + // Adjust column offset. + d.adjustColumnOffset(tblInfo.Columns, tblInfo.Indices, offset) + + columnInfo.State = model.StatePublic + + if err = t.UpdateTable(schemaID, tblInfo); err != nil { + return errors.Trace(err) + } + + // finish this job + job.SchemaState = model.StatePublic + job.State = model.JobDone + return nil + default: + return errors.Errorf("invalid column state %v", columnInfo.State) + } +} + +func (d *ddl) onDropColumn(t *meta.Meta, job *model.Job) error { + // TODO: complete it. + return nil +} + +func (d *ddl) backfillColumn(t table.Table, columnInfo *model.ColumnInfo, version uint64, seekHandle int64) error { + for { + handles, err := d.getSnapshotRows(t, version, seekHandle) + if err != nil { + return errors.Trace(err) + } else if len(handles) == 0 { + return nil + } + + seekHandle = handles[len(handles)-1] + 1 + // TODO: save seekHandle in reorganization job, so we can resume this job later from this handle. + + err = d.backfillColumnData(t, columnInfo, handles) + if err != nil { + return errors.Trace(err) + } + + // update reorgHandle here after every successful batch. + atomic.StoreInt64(&d.reorgHandle, seekHandle) + } +} + +func (d *ddl) backfillColumnData(t table.Table, columnInfo *model.ColumnInfo, handles []int64) error { + for _, handle := range handles { + log.Info("backfill column...", handle) + + err := kv.RunInNewTxn(d.store, true, func(txn kv.Transaction) error { + // First check if row exists. + exist, err := checkRowExist(txn, t, handle) + if err != nil { + return errors.Trace(err) + } else if !exist { + // If row doesn't exist, skip it. + return nil + } + + backfillKey := t.RecordKey(handle, &column.Col{ColumnInfo: *columnInfo}) + _, err = txn.Get(backfillKey) + if err != nil && !kv.IsErrNotFound(err) { + return errors.Trace(err) + } + + // If row column doesn't exist, we need to backfill column. + // Lock row first. + err = txn.LockKeys(t.RecordKey(handle, nil)) + if err != nil { + return errors.Trace(err) + } + + value, _, err := tables.GetColDefaultValue(nil, columnInfo) + if err != nil { + return errors.Trace(err) + } + + err = t.SetColValue(txn, backfillKey, value) + if err != nil { + return errors.Trace(err) + } + + return nil + }) + + if err != nil { + return errors.Trace(err) + } + } + + return nil +} diff --git a/ddl/ddl.go b/ddl/ddl.go index 050030ee53..bdcb4176f0 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -34,8 +34,6 @@ import ( "github.com/pingcap/tidb/parser/coldef" "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/table" - "github.com/pingcap/tidb/table/tables" - "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/charset" qerror "github.com/pingcap/tidb/util/errors" "github.com/twinj/uuid" @@ -182,20 +180,6 @@ func (d *ddl) CreateSchema(ctx context.Context, schema model.CIStr) (err error) return errors.Trace(err) } -func (d *ddl) verifySchemaMetaVersion(t *meta.Meta, schemaMetaVersion int64) error { - curVer, err := t.GetSchemaVersion() - if err != nil { - return errors.Trace(err) - } - if curVer != schemaMetaVersion { - return errors.Errorf("Schema changed, our version %d, but got %d", schemaMetaVersion, curVer) - } - - // Increment version. - _, err = t.GenSchemaVersion() - return errors.Trace(err) -} - func (d *ddl) DropSchema(ctx context.Context, schema model.CIStr) (err error) { is := d.GetInformationSchema() old, ok := is.SchemaByName(schema) @@ -288,7 +272,7 @@ func (d *ddl) buildColumnsAndConstraints(colDefs []*coldef.ColumnDef, constraint } func (d *ddl) buildColumnAndConstraint(offset int, colDef *coldef.ColumnDef) (*column.Col, []*coldef.TableConstraint, error) { - // set charset + // Set charset. if len(colDef.Tp.Charset) == 0 { switch colDef.Tp.Tp { case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: @@ -298,15 +282,17 @@ func (d *ddl) buildColumnAndConstraint(offset int, colDef *coldef.ColumnDef) (*c colDef.Tp.Collate = charset.CharsetBin } } - // convert colDef into col + col, cts, err := coldef.ColumnDefToCol(offset, colDef) if err != nil { return nil, nil, errors.Trace(err) } + col.ID, err = d.genGlobalID() if err != nil { return nil, nil, errors.Trace(err) } + return col, cts, nil } @@ -439,19 +425,6 @@ func (d *ddl) CreateTable(ctx context.Context, ident table.Ident, colDefs []*col } func (d *ddl) AlterTable(ctx context.Context, ident table.Ident, specs []*AlterSpecification) (err error) { - // Get database and table. - is := d.GetInformationSchema() - - schema, ok := is.SchemaByName(ident.Schema) - if !ok { - return errors.Trace(qerror.ErrDatabaseNotExist) - } - - tbl, err := is.TableByName(ident.Schema, ident.Name) - if err != nil { - return errors.Trace(ErrNotExists) - } - // now we only allow one schema changes at the same time. if len(specs) != 1 { return errors.New("can't run multi schema changes in one DDL") @@ -460,7 +433,7 @@ func (d *ddl) AlterTable(ctx context.Context, ident table.Ident, specs []*AlterS for _, spec := range specs { switch spec.Action { case AlterAddColumn: - err = d.addColumn(ctx, schema, tbl, spec, is.SchemaMetaVersion()) + err = d.AddColumn(ctx, ident, spec) case AlterDropIndex: err = d.DropIndex(ctx, ident, model.NewCIStr(spec.Name)) case AlterAddConstr: @@ -481,116 +454,52 @@ func (d *ddl) AlterTable(ctx context.Context, ident table.Ident, specs []*AlterS return errors.Trace(err) } } + return nil } -// Add a column into table -func (d *ddl) addColumn(ctx context.Context, schema *model.DBInfo, tbl table.Table, spec *AlterSpecification, schemaMetaVersion int64) error { - // Find position - cols := tbl.Cols() - position := len(cols) - name := spec.Column.Name - // Check column name duplicate. - dc := column.FindCol(cols, name) - if dc != nil { - return errors.Errorf("Try to add a column with the same name of an already exists column.") - } - if spec.Position.Type == ColumnPositionFirst { - position = 0 - } else if spec.Position.Type == ColumnPositionAfter { - // Find the mentioned column. - c := column.FindCol(cols, spec.Position.RelativeColumn) - if c == nil { - return errors.Errorf("No such column: %v", name) +func checkColumnConstraint(constraints []*coldef.ConstraintOpt) error { + for _, constraint := range constraints { + switch constraint.Tp { + case coldef.ConstrAutoIncrement, coldef.ConstrForeignKey, coldef.ConstrPrimaryKey, coldef.ConstrUniq, coldef.ConstrUniqKey: + return errors.Errorf("unsupported add column constraint - %s", constraint) } - // Insert position is after the mentioned column. - position = c.Offset + 1 } - // TODO: set constraint - col, _, err := d.buildColumnAndConstraint(position, spec.Column) + + return nil +} + +// AddColumn will add a new column to the table. +func (d *ddl) AddColumn(ctx context.Context, ti table.Ident, spec *AlterSpecification) error { + // Check whether the added column constraints are supported. + err := checkColumnConstraint(spec.Column.Constraints) if err != nil { return errors.Trace(err) } - // insert col into the right place of the column list - newCols := make([]*column.Col, 0, len(cols)+1) - newCols = append(newCols, cols[:position]...) - newCols = append(newCols, col) - newCols = append(newCols, cols[position:]...) - // adjust position - if position != len(cols) { - offsetChange := make(map[int]int) - for i := position + 1; i < len(newCols); i++ { - offsetChange[newCols[i].Offset] = i - newCols[i].Offset = i - } - // Update index offset info - for _, idx := range tbl.Indices() { - for _, c := range idx.Columns { - newOffset, ok := offsetChange[c.Offset] - if ok { - c.Offset = newOffset - } - } - } - } - tb := tbl.(*tables.Table) - tb.Columns = newCols - // TODO: update index - if err = updateOldRows(ctx, tb, col); err != nil { - return errors.Trace(err) + is := d.infoHandle.Get() + schema, ok := is.SchemaByName(ti.Schema) + if !ok { + return errors.Trace(qerror.ErrDatabaseNotExist) } - // update infomation schema - err = kv.RunInNewTxn(d.store, false, func(txn kv.Transaction) error { - t := meta.NewMeta(txn) - err := d.verifySchemaMetaVersion(t, schemaMetaVersion) - if err != nil { - return errors.Trace(err) - } + t, err := is.TableByName(ti.Schema, ti.Name) + if err != nil { + return errors.Trace(ErrNotExists) + } - err = t.UpdateTable(schema.ID, tb.Meta()) - return errors.Trace(err) - }) + job := &model.Job{ + SchemaID: schema.ID, + TableID: t.Meta().ID, + Type: model.ActionAddColumn, + Args: []interface{}{spec, 0}, + } + err = d.startJob(ctx, job) err = d.onDDLChange(err) return errors.Trace(err) } -func updateOldRows(ctx context.Context, t *tables.Table, col *column.Col) error { - txn, err := ctx.GetTxn(false) - if err != nil { - return errors.Trace(err) - } - it, err := txn.Seek([]byte(t.FirstKey())) - if err != nil { - return errors.Trace(err) - } - defer it.Close() - - prefix := t.KeyPrefix() - for it.Valid() && strings.HasPrefix(it.Key(), prefix) { - handle, err0 := util.DecodeHandleFromRowKey(it.Key()) - if err0 != nil { - return errors.Trace(err0) - } - k := t.RecordKey(handle, col) - - // TODO: check and get timestamp/datetime default value. - // refer to getDefaultValue in stmt/stmts/stmt_helper.go. - if err0 = t.SetColValue(txn, k, col.DefaultValue); err0 != nil { - return errors.Trace(err0) - } - - rk := t.RecordKey(handle, nil) - if err0 = kv.NextUntil(it, util.RowKeyPrefixFilter(rk)); err0 != nil { - return errors.Trace(err0) - } - } - - return nil -} - // DropTable will proceed even if some table in the list does not exists. func (d *ddl) DropTable(ctx context.Context, ti table.Ident) (err error) { is := d.GetInformationSchema() @@ -693,3 +602,15 @@ func (d *ddl) DropIndex(ctx context.Context, ti table.Ident, indexName model.CIS err = d.onDDLChange(err) return errors.Trace(err) } + +// findCol finds column in cols by name. +func findCol(cols []*model.ColumnInfo, name string) *model.ColumnInfo { + name = strings.ToLower(name) + for _, col := range cols { + if col.Name.L == name { + return col + } + } + + return nil +} diff --git a/ddl/ddl_test.go b/ddl/ddl_test.go index 2f55dd4d3d..7eb60450a4 100644 --- a/ddl/ddl_test.go +++ b/ddl/ddl_test.go @@ -49,7 +49,7 @@ func (ts *testSuite) SetUpSuite(c *C) { ts.store = store } -func (ts *testSuite) TestT(c *C) { +func (ts *testSuite) TestDDL(c *C) { se, _ := tidb.CreateSession(ts.store) ctx := se.(context.Context) schemaName := model.NewCIStr("test_ddl") @@ -89,6 +89,9 @@ func (ts *testSuite) TestT(c *C) { alterStmt := statement(`alter table t2 add b enum("bb") first`).(*stmts.AlterTableStmt) sessionctx.GetDomain(ctx).DDL().AlterTable(ctx, tbIdent2, alterStmt.Specs) c.Assert(alterStmt.Specs[0].String(), Not(Equals), "") + tb, err = sessionctx.GetDomain(ctx).InfoSchema().TableByName(tbIdent2.Schema, tbIdent2.Name) + c.Assert(err, IsNil) + c.Assert(tb, NotNil) cols, err := tb.Row(ctx, rid0) c.Assert(err, IsNil) c.Assert(len(cols), Equals, 2) diff --git a/ddl/index.go b/ddl/index.go index 4f1dd6cda6..24651fe7ac 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -15,7 +15,6 @@ package ddl import ( "bytes" - "strings" "sync/atomic" "github.com/juju/errors" @@ -30,39 +29,6 @@ import ( "github.com/pingcap/tidb/util/errors2" ) -func (d *ddl) getTableInfo(t *meta.Meta, job *model.Job) (*model.TableInfo, error) { - schemaID := job.SchemaID - tableID := job.TableID - tblInfo, err := t.GetTable(schemaID, tableID) - if errors2.ErrorEqual(err, meta.ErrDBNotExists) { - job.State = model.JobCancelled - return nil, errors.Trace(ErrNotExists) - } else if err != nil { - return nil, errors.Trace(err) - } else if tblInfo == nil { - job.State = model.JobCancelled - return nil, errors.Trace(ErrNotExists) - } - - if tblInfo.State != model.StatePublic { - job.State = model.JobCancelled - return nil, errors.Errorf("table %s is not in public, but %s", tblInfo.Name.L, tblInfo.State) - } - - return tblInfo, nil -} - -// FindCol finds column in cols by name. -func findCol(cols []*model.ColumnInfo, name string) (c *model.ColumnInfo) { - name = strings.ToLower(name) - for _, c = range cols { - if c.Name.L == name { - return - } - } - return nil -} - func buildIndexInfo(tblInfo *model.TableInfo, unique bool, indexName model.CIStr, idxColNames []*coldef.IndexColName) (*model.IndexInfo, error) { for _, col := range tblInfo.Columns { if col.Name.L == indexName.L { @@ -129,7 +95,7 @@ func dropIndexColumnFlag(tblInfo *model.TableInfo, indexInfo *model.IndexInfo) { } } -func (d *ddl) onIndexCreate(t *meta.Meta, job *model.Job) error { +func (d *ddl) onCreateIndex(t *meta.Meta, job *model.Job) error { schemaID := job.SchemaID tblInfo, err := d.getTableInfo(t, job) if err != nil { @@ -189,7 +155,7 @@ func (d *ddl) onIndexCreate(t *meta.Meta, job *model.Job) error { err = t.UpdateTable(schemaID, tblInfo) return errors.Trace(err) case model.StateWriteOnly: - // write only -> public + // write only -> reorganization job.SchemaState = model.StateReorganization indexInfo.State = model.StateReorganization // initialize SnapshotVer to 0 for later reorganization check. @@ -250,7 +216,7 @@ func (d *ddl) onIndexCreate(t *meta.Meta, job *model.Job) error { } } -func (d *ddl) onIndexDrop(t *meta.Meta, job *model.Job) error { +func (d *ddl) onDropIndex(t *meta.Meta, job *model.Job) error { schemaID := job.SchemaID tblInfo, err := d.getTableInfo(t, job) if err != nil { diff --git a/ddl/schema.go b/ddl/schema.go index db4eafe1b0..57d69a87a6 100644 --- a/ddl/schema.go +++ b/ddl/schema.go @@ -22,7 +22,7 @@ import ( "github.com/pingcap/tidb/util/errors2" ) -func (d *ddl) onSchemaCreate(t *meta.Meta, job *model.Job) error { +func (d *ddl) onCreateSchema(t *meta.Meta, job *model.Job) error { schemaID := job.SchemaID var name model.CIStr if err := job.DecodeArgs(&name); err != nil { @@ -90,7 +90,7 @@ func (d *ddl) onSchemaCreate(t *meta.Meta, job *model.Job) error { } } -func (d *ddl) onSchemaDrop(t *meta.Meta, job *model.Job) error { +func (d *ddl) onDropSchema(t *meta.Meta, job *model.Job) error { dbInfo, err := t.GetDatabase(job.SchemaID) if err != nil { return errors.Trace(err) diff --git a/ddl/table.go b/ddl/table.go index dbc4a7e1c4..1e03c57ea1 100644 --- a/ddl/table.go +++ b/ddl/table.go @@ -23,7 +23,7 @@ import ( "github.com/pingcap/tidb/util/errors2" ) -func (d *ddl) onTableCreate(t *meta.Meta, job *model.Job) error { +func (d *ddl) onCreateTable(t *meta.Meta, job *model.Job) error { schemaID := job.SchemaID tbInfo := &model.TableInfo{} if err := job.DecodeArgs(tbInfo); err != nil { @@ -89,7 +89,7 @@ func (d *ddl) onTableCreate(t *meta.Meta, job *model.Job) error { } } -func (d *ddl) onTableDrop(t *meta.Meta, job *model.Job) error { +func (d *ddl) onDropTable(t *meta.Meta, job *model.Job) error { schemaID := job.SchemaID tableID := job.TableID @@ -170,6 +170,28 @@ func (d *ddl) getTable(t *meta.Meta, schemaID int64, tblInfo *model.TableInfo) ( return tbl, nil } +func (d *ddl) getTableInfo(t *meta.Meta, job *model.Job) (*model.TableInfo, error) { + schemaID := job.SchemaID + tableID := job.TableID + tblInfo, err := t.GetTable(schemaID, tableID) + if errors2.ErrorEqual(err, meta.ErrDBNotExists) { + job.State = model.JobCancelled + return nil, errors.Trace(ErrNotExists) + } else if err != nil { + return nil, errors.Trace(err) + } else if tblInfo == nil { + job.State = model.JobCancelled + return nil, errors.Trace(ErrNotExists) + } + + if tblInfo.State != model.StatePublic { + job.State = model.JobCancelled + return nil, errors.Errorf("table %s is not in public, but %s", tblInfo.Name.L, tblInfo.State) + } + + return tblInfo, nil +} + func (d *ddl) dropTableData(t table.Table) error { // delete table data err := d.delKeysWithPrefix(t.KeyPrefix()) diff --git a/ddl/worker.go b/ddl/worker.go index 3e56a9e892..07a652562e 100644 --- a/ddl/worker.go +++ b/ddl/worker.go @@ -280,19 +280,21 @@ func (d *ddl) runJob(t *meta.Meta, job *model.Job) error { var err error switch job.Type { case model.ActionCreateSchema: - err = d.onSchemaCreate(t, job) + err = d.onCreateSchema(t, job) case model.ActionDropSchema: - err = d.onSchemaDrop(t, job) + err = d.onDropSchema(t, job) case model.ActionCreateTable: - err = d.onTableCreate(t, job) + err = d.onCreateTable(t, job) case model.ActionDropTable: - err = d.onTableDrop(t, job) + err = d.onDropTable(t, job) case model.ActionAddColumn: + err = d.onAddColumn(t, job) case model.ActionDropColumn: + err = d.onDropColumn(t, job) case model.ActionAddIndex: - err = d.onIndexCreate(t, job) + err = d.onCreateIndex(t, job) case model.ActionDropIndex: - err = d.onIndexDrop(t, job) + err = d.onDropIndex(t, job) case model.ActionAddConstraint: log.Fatal("Doesn't support change constraint online") case model.ActionDropConstraint: diff --git a/meta/meta.go b/meta/meta.go index cf440f0793..bf760cdb66 100644 --- a/meta/meta.go +++ b/meta/meta.go @@ -125,13 +125,13 @@ func (m *Meta) parseTableID(key string) (int64, error) { // GenAutoTableID adds step to the auto id of the table and returns the sum. func (m *Meta) GenAutoTableID(dbID int64, tableID int64, step int64) (int64, error) { - // check db exists + // Check if db exists. dbKey := m.dbKey(dbID) if err := m.checkDBExists(dbKey); err != nil { return 0, errors.Trace(err) } - // check table exists + // Check if table exists. tableKey := m.tableKey(tableID) if err := m.checkTableExists(dbKey, tableKey); err != nil { return 0, errors.Trace(err) @@ -239,14 +239,14 @@ func (m *Meta) UpdateDatabase(dbInfo *model.DBInfo) error { // CreateTable creates a table with tableInfo in database. func (m *Meta) CreateTable(dbID int64, tableInfo *model.TableInfo) error { - // first check db exists or not. + // Check if db exists. dbKey := m.dbKey(dbID) if err := m.checkDBExists(dbKey); err != nil { return errors.Trace(err) } + // Check if table exists. tableKey := m.tableKey(tableInfo.ID) - // then check table exists or not if err := m.checkTableNotExists(dbKey, tableKey); err != nil { return errors.Trace(err) } @@ -261,7 +261,7 @@ func (m *Meta) CreateTable(dbID int64, tableInfo *model.TableInfo) error { // DropDatabase drops whole database. func (m *Meta) DropDatabase(dbID int64) error { - // check if db exists. + // Check if db exists. dbKey := m.dbKey(dbID) if err := m.txn.HClear(dbKey); err != nil { return errors.Trace(err) @@ -276,14 +276,14 @@ func (m *Meta) DropDatabase(dbID int64) error { // DropTable drops table in database. func (m *Meta) DropTable(dbID int64, tableID int64) error { - // first check db exists or not. + // Check if db exists. dbKey := m.dbKey(dbID) if err := m.checkDBExists(dbKey); err != nil { return errors.Trace(err) } + // Check if table exists. tableKey := m.tableKey(tableID) - // then check table exists or not if err := m.checkTableExists(dbKey, tableKey); err != nil { return errors.Trace(err) } @@ -301,15 +301,14 @@ func (m *Meta) DropTable(dbID int64, tableID int64) error { // UpdateTable updates the table with table info. func (m *Meta) UpdateTable(dbID int64, tableInfo *model.TableInfo) error { - // first check db exists or not. + // Check if db exists. dbKey := m.dbKey(dbID) if err := m.checkDBExists(dbKey); err != nil { return errors.Trace(err) } + // Check if table exists. tableKey := m.tableKey(tableInfo.ID) - - // then check table exists or not if err := m.checkTableExists(dbKey, tableKey); err != nil { return errors.Trace(err) } @@ -320,7 +319,6 @@ func (m *Meta) UpdateTable(dbID int64, tableInfo *model.TableInfo) error { } err = m.txn.HSet(dbKey, tableKey, data) - return errors.Trace(err) } @@ -390,7 +388,7 @@ func (m *Meta) GetDatabase(dbID int64) (*model.DBInfo, error) { // GetTable gets the table value in database with tableID. func (m *Meta) GetTable(dbID int64, tableID int64) (*model.TableInfo, error) { - // first check db exists or not. + // Check if db exists. dbKey := m.dbKey(dbID) if err := m.checkDBExists(dbKey); err != nil { return nil, errors.Trace(err) diff --git a/plan/plans/from_test.go b/plan/plans/from_test.go index 12eb2e99af..ba0d26cffa 100644 --- a/plan/plans/from_test.go +++ b/plan/plans/from_test.go @@ -83,6 +83,7 @@ func (p *testFromSuit) SetUpSuite(c *C) { Offset: 0, DefaultValue: 0, FieldType: *types.NewFieldType(mysql.TypeLonglong), + State: model.StatePublic, }, }, { @@ -92,6 +93,7 @@ func (p *testFromSuit) SetUpSuite(c *C) { Offset: 1, DefaultValue: nil, FieldType: *types.NewFieldType(mysql.TypeVarchar), + State: model.StatePublic, }, }, } @@ -100,7 +102,8 @@ func (p *testFromSuit) SetUpSuite(c *C) { var i int64 for i = 0; i < 10; i++ { - p.tbl.AddRecord(p, []interface{}{i * 10, "hello"}) + _, err = p.tbl.AddRecord(p, []interface{}{i * 10, "hello"}) + c.Assert(err, IsNil) } } diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index c32c7fbdcf..9f3ecabb20 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/stmt" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/util/errors2" "github.com/pingcap/tidb/util/format" "github.com/pingcap/tidb/util/types" @@ -92,6 +93,7 @@ func (s *InsertValues) execSelect(t table.Table, cols []*column.Col, ctx context return nil, errors.Trace(err) } defer r.Close() + if len(r.GetFields()) != len(cols) { return nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(r.GetFields())) } @@ -107,26 +109,29 @@ func (s *InsertValues) execSelect(t table.Table, cols []*column.Col, ctx context if row == nil { break } - data0 := make([]interface{}, len(t.Cols())) - marked := make(map[int]struct{}, len(cols)) - for i, d := range row.Data { - data0[cols[i].Offset] = d - marked[cols[i].Offset] = struct{}{} + + currentRow := make([]interface{}, len(t.Cols())) + marked := make(map[int]struct{}, len(t.Cols())) + for i, data := range row.Data { + offset := cols[i].Offset + currentRow[offset] = data + marked[offset] = struct{}{} } - if err = s.initDefaultValues(ctx, t, data0, marked); err != nil { + if err = s.initDefaultValues(ctx, t, currentRow, marked); err != nil { return nil, errors.Trace(err) } - if err = column.CastValues(ctx, data0, cols); err != nil { + if err = column.CastValues(ctx, currentRow, cols); err != nil { return nil, errors.Trace(err) } - if err = column.CheckNotNull(t.Cols(), data0); err != nil { + if err = column.CheckNotNull(t.Cols(), currentRow); err != nil { return nil, errors.Trace(err) } + var v interface{} - v, err = types.Clone(data0) + v, err = types.Clone(currentRow) if err != nil { return nil, errors.Trace(err) } @@ -190,19 +195,19 @@ func (s *InsertValues) getColumns(tableCols []*column.Col) ([]*column.Col, error return cols, nil } -func (s *InsertValues) getDefaultValues(ctx context.Context, cols []*column.Col) (map[interface{}]interface{}, error) { - m := map[interface{}]interface{}{} - for _, v := range cols { - if value, ok, err := getDefaultValue(ctx, v); ok { +func (s *InsertValues) getColumnDefaultValues(ctx context.Context, cols []*column.Col) (map[interface{}]interface{}, error) { + defaultValMap := map[interface{}]interface{}{} + for _, col := range cols { + if value, ok, err := tables.GetColDefaultValue(ctx, &col.ColumnInfo); ok { if err != nil { return nil, errors.Trace(err) } - m[v.Name.L] = value + defaultValMap[col.Name.L] = value } } - return m, nil + return defaultValMap, nil } func (s *InsertValues) fillValueList() error { @@ -227,6 +232,7 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) if err != nil { return nil, errors.Trace(err) } + cols, err := s.getColumns(t.Cols()) if err != nil { return nil, errors.Trace(err) @@ -243,10 +249,11 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) return nil, errors.Trace(err) } - m, err := s.getDefaultValues(ctx, t.Cols()) + defaultValMap, err := s.getColumnDefaultValues(ctx, t.Cols()) if err != nil { return nil, errors.Trace(err) } + insertValueCount := len(s.Lists[0]) toUpdateColumns, err := getOnDuplicateUpdateColumns(s.OnDuplicate, t) if err != nil { @@ -258,7 +265,7 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) return nil, errors.Trace(err) } - row, err := s.getRow(ctx, t, cols, list, m) + row, err := s.fillRowData(ctx, t, cols, list, defaultValMap) if err != nil { return nil, errors.Trace(err) } @@ -276,6 +283,7 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrKeyExists) { return nil, errors.Trace(err) } + if err = execOnDuplicateUpdate(ctx, t, row, h, toUpdateColumns); err != nil { return nil, errors.Trace(err) } @@ -303,36 +311,40 @@ func (s *InsertValues) checkValueCount(insertValueCount, valueCount, num int, co return nil } -func (s *InsertValues) getRow(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, error) { - r := make([]interface{}, len(t.Cols())) +func (s *InsertValues) fillRowData(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, evalMap map[interface{}]interface{}) ([]interface{}, error) { + row := make([]interface{}, len(t.Cols())) marked := make(map[int]struct{}, len(list)) for i, expr := range list { // For "insert into t values (default)" Default Eval. - m[expression.ExprEvalDefaultName] = cols[i].Name.O + evalMap[expression.ExprEvalDefaultName] = cols[i].Name.O - val, err := expr.Eval(ctx, m) + val, err := expr.Eval(ctx, evalMap) if err != nil { return nil, errors.Trace(err) } - r[cols[i].Offset] = val - marked[cols[i].Offset] = struct{}{} + + offset := cols[i].Offset + row[offset] = val + marked[offset] = struct{}{} } // Clear last insert id. variable.GetSessionVars(ctx).SetLastInsertID(0) - err := s.initDefaultValues(ctx, t, r, marked) + err := s.initDefaultValues(ctx, t, row, marked) if err != nil { return nil, errors.Trace(err) } - if err = column.CastValues(ctx, r, cols); err != nil { - return nil, errors.Trace(err) - } - if err = column.CheckNotNull(t.Cols(), r); err != nil { + + if err = column.CastValues(ctx, row, cols); err != nil { return nil, errors.Trace(err) } - return r, nil + if err = column.CheckNotNull(t.Cols(), row); err != nil { + return nil, errors.Trace(err) + } + + return row, nil } func execOnDuplicateUpdate(ctx context.Context, t table.Table, row []interface{}, h int64, cols map[int]*expression.Assignment) error { @@ -396,7 +408,7 @@ func (s *InsertValues) initDefaultValues(ctx context.Context, t table.Table, row variable.GetSessionVars(ctx).SetLastInsertID(uint64(id)) } else { var value interface{} - value, _, err = getDefaultValue(ctx, c) + value, _, err = tables.GetColDefaultValue(ctx, &c.ColumnInfo) if err != nil { return errors.Trace(err) } diff --git a/stmt/stmts/replace.go b/stmt/stmts/replace.go index 86c5ea41ee..edd86b1dcf 100644 --- a/stmt/stmts/replace.go +++ b/stmt/stmts/replace.go @@ -61,6 +61,7 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error if err != nil { return nil, errors.Trace(err) } + cols, err := s.getColumns(t.Cols()) if err != nil { return nil, errors.Trace(err) @@ -71,28 +72,33 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error if s.Sel != nil { return s.execSelect(t, cols, ctx) } + // Process `replace ... set x=y ...` if err = s.fillValueList(); err != nil { return nil, errors.Trace(err) } - m, err := s.getDefaultValues(ctx, t.Cols()) + + evalMap, err := s.getColumnDefaultValues(ctx, t.Cols()) if err != nil { return nil, errors.Trace(err) } - replaceValueCount := len(s.Lists[0]) + replaceValueCount := len(s.Lists[0]) for i, list := range s.Lists { if err = s.checkValueCount(replaceValueCount, len(list), i, cols); err != nil { return nil, errors.Trace(err) } - row, err := s.getRow(ctx, t, cols, list, m) + + row, err := s.fillRowData(ctx, t, cols, list, evalMap) if err != nil { return nil, errors.Trace(err) } + h, err := t.AddRecord(ctx, row) if err == nil { continue } + if err != nil && !errors2.ErrorEqual(err, kv.ErrKeyExists) { return nil, errors.Trace(err) } @@ -115,18 +121,20 @@ func replaceRow(ctx context.Context, t table.Table, handle int64, replaceRow []i return errors.Trace(err) } + result := 0 isReplace := false - touched := make([]bool, len(row)) + touched := make(map[int]bool, len(row)) for i, val := range row { - v, err1 := types.Compare(val, replaceRow[i]) - if err1 != nil { - return errors.Trace(err1) + result, err = types.Compare(val, replaceRow[i]) + if err != nil { + return errors.Trace(err) } - if v != 0 { + if result != 0 { touched[i] = true isReplace = true } } + if isReplace { variable.GetSessionVars(ctx).AddAffectedRows(1) if err = t.UpdateRecord(ctx, handle, row, replaceRow, touched); err != nil { diff --git a/stmt/stmts/stmt_helper.go b/stmt/stmts/stmt_helper.go index cca52a8f5b..2b49859b42 100644 --- a/stmt/stmts/stmt_helper.go +++ b/stmt/stmts/stmt_helper.go @@ -14,44 +14,11 @@ package stmts import ( - "github.com/juju/errors" - "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/context" - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" ) -func getDefaultValue(ctx context.Context, c *column.Col) (interface{}, bool, error) { - // Check no default value flag. - if mysql.HasNoDefaultValueFlag(c.Flag) && c.Tp != mysql.TypeEnum { - return nil, false, errors.Errorf("Field '%s' doesn't have a default value", c.Name) - } - - // Check and get timestamp/datetime default value. - if c.Tp == mysql.TypeTimestamp || c.Tp == mysql.TypeDatetime { - if c.DefaultValue == nil { - return nil, true, nil - } - - value, err := expression.GetTimeValue(ctx, c.DefaultValue, c.Tp, c.Decimal) - if err != nil { - return nil, true, errors.Errorf("Field '%s' get default value fail - %s", c.Name, errors.Trace(err)) - } - - return value, true, nil - } else if c.Tp == mysql.TypeEnum { - // For enum type, if no default value and not null is set, - // the default value is the first element of the enum list - if c.DefaultValue == nil && mysql.HasNotNullFlag(c.Flag) { - return c.FieldType.Elems[0], true, nil - } - } - - return c.DefaultValue, true, nil -} - func getTable(ctx context.Context, tableIdent table.Ident) (table.Table, error) { full := tableIdent.Full(ctx) return sessionctx.GetDomain(ctx).InfoSchema().TableByName(full.Schema, full.Name) diff --git a/stmt/stmts/update.go b/stmt/stmts/update.go index 56e92484f3..28b108a328 100644 --- a/stmt/stmts/update.go +++ b/stmt/stmts/update.go @@ -119,17 +119,17 @@ func getUpdateColumns(assignList []*expression.Assignment, fields []*field.Resul } func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Table, - updateColumns map[int]*expression.Assignment, m map[interface{}]interface{}, + updateColumns map[int]*expression.Assignment, evalMap map[interface{}]interface{}, offset int, onDuplicateUpdate bool) error { if err := t.LockRow(ctx, h, true); err != nil { return errors.Trace(err) } - oldData := make([]interface{}, len(t.Cols())) - touched := make([]bool, len(t.Cols())) - copy(oldData, data) - cols := t.Cols() + oldData := data + newData := make([]interface{}, len(cols)) + touched := make(map[int]bool, len(cols)) + copy(newData, oldData) assignExists := false for i, asgn := range updateColumns { @@ -137,47 +137,49 @@ func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Tabl // The assign expression is for another table, not this. continue } - val, err := asgn.Expr.Eval(ctx, m) + + val, err := asgn.Expr.Eval(ctx, evalMap) if err != nil { - return err + return errors.Trace(err) } + colIndex := i - offset touched[colIndex] = true - data[colIndex] = val + newData[colIndex] = val assignExists = true } - // no assign list for this table, no need to update. + // If no assign list for this table, no need to update. if !assignExists { return nil } // Check whether new value is valid. - if err := column.CastValues(ctx, data, t.Cols()); err != nil { - return err + if err := column.CastValues(ctx, newData, cols); err != nil { + return errors.Trace(err) } - if err := column.CheckNotNull(t.Cols(), data); err != nil { - return err + if err := column.CheckNotNull(cols, newData); err != nil { + return errors.Trace(err) } // If row is not changed, we should do nothing. rowChanged := false - for i, d := range data { + for i := range oldData { if !touched[i] { continue } - od := oldData[i] - n, err := types.Compare(d, od) + + n, err := types.Compare(newData[i], oldData[i]) if err != nil { return errors.Trace(err) } - if n != 0 { rowChanged = true break } } + if !rowChanged { // See: https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html CLIENT_FOUND_ROWS if variable.GetSessionVars(ctx).ClientCapability&mysql.ClientFoundRows > 0 { @@ -187,10 +189,11 @@ func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Tabl } // Update record to new value and update index. - err := t.UpdateRecord(ctx, h, oldData, data, touched) + err := t.UpdateRecord(ctx, h, oldData, newData, touched) if err != nil { return errors.Trace(err) } + // Record affected rows. if !onDuplicateUpdate { variable.GetSessionVars(ctx).AddAffectedRows(1) @@ -198,6 +201,7 @@ func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Tabl variable.GetSessionVars(ctx).AddAffectedRows(2) } + return nil } @@ -253,17 +257,13 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { return nil, errors.Trace(err) } defer p.Close() - updatedRowKeys := make(map[string]bool) - // Get table alias map. fs := p.GetFields() - - columns, err0 := getUpdateColumns(s.List, fs) - if err0 != nil { - return nil, errors.Trace(err0) + columns, err := getUpdateColumns(s.List, fs) + if err != nil { + return nil, errors.Trace(err) } - m := map[interface{}]interface{}{} var records []*plan.Row for { row, err1 := p.Next(ctx) @@ -280,15 +280,17 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { records = append(records, row) } + evalMap := map[interface{}]interface{}{} + updatedRowKeys := make(map[string]bool) for _, row := range records { rowData := row.Data - // Set ExprEvalIdentReferFunc - m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + // Set ExprEvalIdentReferFunc. + evalMap[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { return rowData[index], nil } - // Update rows + // Update rows. offset := 0 for _, entry := range row.RowKeys { tbl := entry.Tbl @@ -302,13 +304,14 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { // Each matching row is updated once, even if it matches the conditions multiple times. continue } + // Update row handle, err2 := util.DecodeHandleFromRowKey(k) if err2 != nil { return nil, errors.Trace(err2) } - err2 = updateRecord(ctx, handle, data, tbl, columns, m, lastOffset, false) + err2 = updateRecord(ctx, handle, data, tbl, columns, evalMap, lastOffset, false) if err2 != nil { return nil, errors.Trace(err2) } diff --git a/table/table.go b/table/table.go index 3e1108c264..79e8e6a5ad 100644 --- a/table/table.go +++ b/table/table.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/sessionctx/db" @@ -56,7 +57,7 @@ type Table interface { // TableName returns table name. TableName() model.CIStr - // Cols returns the columns of the table. + // Cols returns the columns of the table which is used in select. Cols() []*column.Col // Indices returns the indices of the table. @@ -87,7 +88,7 @@ type Table interface { AddRecord(ctx context.Context, r []interface{}) (recordID int64, err error) // UpdateRecord updates a row in the table. - UpdateRecord(ctx context.Context, h int64, currData []interface{}, newData []interface{}, touched []bool) error + UpdateRecord(ctx context.Context, h int64, currData []interface{}, newData []interface{}, touched map[int]bool) error // TableID returns the ID of the table. TableID() int64 @@ -107,6 +108,10 @@ type Table interface { // LockRow locks a row. // If update is true, set row lock key to current txn. LockRow(ctx context.Context, h int64, update bool) error + + // SetColValue sets the column value. + // If the column is untouched, we don't need to do this. + SetColValue(txn kv.Transaction, key []byte, data interface{}) error } // TableFromMeta builds a table.Table from *model.TableInfo. diff --git a/table/tables/tables.go b/table/tables/tables.go index 300a7746f1..198b8f5965 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -44,11 +44,13 @@ type Table struct { Name model.CIStr Columns []*column.Col - indices []*column.IndexedCol - recordPrefix string - indexPrefix string - alloc autoid.Allocator - state model.SchemaState + publicColumns []*column.Col + writableColumns []*column.Col + indices []*column.IndexedCol + recordPrefix string + indexPrefix string + alloc autoid.Allocator + state model.SchemaState } // TableFromMeta creates a Table instance from model.TableInfo. @@ -56,8 +58,8 @@ func TableFromMeta(alloc autoid.Allocator, tblInfo *model.TableInfo) table.Table t := NewTable(tblInfo.ID, tblInfo.Name.O, nil, alloc) for _, colInfo := range tblInfo.Columns { - c := column.Col{ColumnInfo: *colInfo} - t.Columns = append(t.Columns, &c) + col := &column.Col{ColumnInfo: *colInfo} + t.Columns = append(t.Columns, col) } for _, idxInfo := range tblInfo.Indices { @@ -84,6 +86,9 @@ func NewTable(tableID int64, tableName string, cols []*column.Col, alloc autoid. Columns: cols, state: model.StatePublic, } + + t.publicColumns = t.Cols() + t.writableColumns = t.writableCols() return t } @@ -129,7 +134,35 @@ func (t *Table) Meta() *model.TableInfo { // Cols implements table.Table Cols interface. func (t *Table) Cols() []*column.Col { - return t.Columns + if len(t.publicColumns) > 0 { + return t.publicColumns + } + + t.publicColumns = make([]*column.Col, 0, len(t.Columns)) + for _, col := range t.Columns { + if col.State == model.StatePublic { + t.publicColumns = append(t.publicColumns, col) + } + } + + return t.publicColumns +} + +func (t *Table) writableCols() []*column.Col { + if len(t.writableColumns) > 0 { + return t.writableColumns + } + + t.writableColumns = make([]*column.Col, 0, len(t.Columns)) + for _, col := range t.Columns { + if col.State == model.StateDeleteOnly { + continue + } + + t.writableColumns = append(t.writableColumns, col) + } + + return t.writableColumns } func (t *Table) unflatten(rec interface{}, col *column.Col) (interface{}, error) { @@ -237,43 +270,47 @@ func (t *Table) Truncate(ctx context.Context) error { } // UpdateRecord implements table.Table UpdateRecord interface. -func (t *Table) UpdateRecord(ctx context.Context, h int64, currData []interface{}, newData []interface{}, touched []bool) error { - // if they are not set, and other data are changed, they will be updated by current timestamp too. - // set on update value - err := t.setOnUpdateData(ctx, touched, newData) +func (t *Table) UpdateRecord(ctx context.Context, h int64, oldData []interface{}, newData []interface{}, touched map[int]bool) error { + // We should check whether this table has on update column which state is write only. + currentData := make([]interface{}, len(t.writableCols())) + copy(currentData, newData) + + // If they are not set, and other data are changed, they will be updated by current timestamp too. + err := t.setOnUpdateData(ctx, touched, currentData) if err != nil { return errors.Trace(err) } // set new value - if err := t.setNewData(ctx, h, newData); err != nil { + if err := t.setNewData(ctx, h, touched, currentData); err != nil { return errors.Trace(err) } // rebuild index - if err := t.rebuildIndices(ctx, h, touched, currData, newData); err != nil { + if err := t.rebuildIndices(ctx, h, touched, oldData, currentData); err != nil { return errors.Trace(err) } + return nil } -func (t *Table) setOnUpdateData(ctx context.Context, touched []bool, data []interface{}) error { - ucols := column.FindOnUpdateCols(t.Cols()) - for _, c := range ucols { - if !touched[c.Offset] { - v, err := expression.GetTimeValue(ctx, expression.CurrentTimestamp, c.Tp, c.Decimal) +func (t *Table) setOnUpdateData(ctx context.Context, touched map[int]bool, data []interface{}) error { + ucols := column.FindOnUpdateCols(t.writableCols()) + for _, col := range ucols { + if !touched[col.Offset] { + value, err := expression.GetTimeValue(ctx, expression.CurrentTimestamp, col.Tp, col.Decimal) if err != nil { return errors.Trace(err) } - data[c.Offset] = v - touched[c.Offset] = true + + data[col.Offset] = value + touched[col.Offset] = true } } return nil } -// SetColValue sets the column value. -// If the column untouched, we don't need to do this. +// SetColValue implements table.Table SetColValue interface. func (t *Table) SetColValue(txn kv.Transaction, key []byte, data interface{}) error { v, err := t.EncodeValue(data) if err != nil { @@ -285,21 +322,27 @@ func (t *Table) SetColValue(txn kv.Transaction, key []byte, data interface{}) er return nil } -func (t *Table) setNewData(ctx context.Context, h int64, data []interface{}) error { +func (t *Table) setNewData(ctx context.Context, h int64, touched map[int]bool, data []interface{}) error { txn, err := ctx.GetTxn(false) if err != nil { return errors.Trace(err) } + for _, col := range t.Cols() { + if !touched[col.Offset] { + continue + } + k := t.RecordKey(h, col) if err := t.SetColValue(txn, k, data[col.Offset]); err != nil { return errors.Trace(err) } } + return nil } -func (t *Table) rebuildIndices(ctx context.Context, h int64, touched []bool, oldData, newData []interface{}) error { +func (t *Table) rebuildIndices(ctx context.Context, h int64, touched map[int]bool, oldData []interface{}, newData []interface{}) error { for _, idx := range t.Indices() { idxTouched := false for _, ic := range idx.Columns { @@ -325,6 +368,7 @@ func (t *Table) rebuildIndices(ctx context.Context, h int64, touched []bool, old if err != nil { return errors.Trace(err) } + if err := t.BuildIndexForRow(ctx, h, newVs, idx); err != nil { return errors.Trace(err) } @@ -376,13 +420,25 @@ func (t *Table) AddRecord(ctx context.Context, r []interface{}) (recordID int64, return 0, errors.Trace(err) } - // column key -> column value - for _, c := range t.Cols() { - k := t.RecordKey(recordID, c) - if err := t.SetColValue(txn, k, r[c.Offset]); err != nil { + // Set public and write only column value. + for _, col := range t.writableCols() { + var value interface{} + key := t.RecordKey(recordID, col) + if col.State == model.StateWriteOnly { + value, _, err = GetColDefaultValue(ctx, &col.ColumnInfo) + if err != nil { + return 0, errors.Trace(err) + } + } else { + value = r[col.Offset] + } + + err = t.SetColValue(txn, key, value) + if err != nil { return 0, errors.Trace(err) } } + variable.GetSessionVars(ctx).AddAffectedRows(1) return recordID, nil } @@ -412,20 +468,25 @@ func (t *Table) RowWithCols(ctx context.Context, h int64, cols []*column.Col) ([ if err != nil { return nil, errors.Trace(err) } + // use the length of t.Cols() for alignment v := make([]interface{}, len(t.Cols())) - for _, c := range cols { - k := t.RecordKey(h, c) + for _, col := range cols { + if col.State != model.StatePublic { + return nil, errors.Errorf("Cannot use none public column - %v", cols) + } + + k := t.RecordKey(h, col) data, err := txn.Get([]byte(k)) if err != nil { return nil, errors.Trace(err) } - val, err := t.DecodeValue(data, c) + val, err := t.DecodeValue(data, col) if err != nil { return nil, errors.Trace(err) } - v[c.Offset] = val + v[col.Offset] = val } return v, nil } @@ -471,10 +532,16 @@ func (t *Table) RemoveRow(ctx context.Context, h int64) error { return errors.Trace(err) } // Remove row's colume one by one - for _, col := range t.Cols() { + for _, col := range t.Columns { k := t.RecordKey(h, col) err := txn.Delete([]byte(k)) if err != nil { + if col.State != model.StatePublic && errors2.ErrorEqual(err, kv.ErrNotExist) { + // If the column is not in public state, we may have not added the column, + // or already deleted the column, so skip ErrNotExist error. + continue + } + return errors.Trace(err) } } @@ -600,6 +667,36 @@ func (t *Table) AllocAutoID() (int64, error) { return t.alloc.Alloc(t.ID) } +// GetColDefaultValue gets default value of the column. +func GetColDefaultValue(ctx context.Context, col *model.ColumnInfo) (interface{}, bool, error) { + // Check no default value flag. + if mysql.HasNoDefaultValueFlag(col.Flag) && col.Tp != mysql.TypeEnum { + return nil, false, errors.Errorf("Field '%s' doesn't have a default value", col.Name) + } + + // Check and get timestamp/datetime default value. + if col.Tp == mysql.TypeTimestamp || col.Tp == mysql.TypeDatetime { + if col.DefaultValue == nil { + return nil, true, nil + } + + value, err := expression.GetTimeValue(ctx, col.DefaultValue, col.Tp, col.Decimal) + if err != nil { + return nil, true, errors.Errorf("Field '%s' get default value fail - %s", col.Name, errors.Trace(err)) + } + + return value, true, nil + } else if col.Tp == mysql.TypeEnum { + // For enum type, if no default value and not null is set, + // the default value is the first element of the enum list + if col.DefaultValue == nil && mysql.HasNotNullFlag(col.Flag) { + return col.FieldType.Elems[0], true, nil + } + } + + return col.DefaultValue, true, nil +} + func init() { table.TableFromMeta = TableFromMeta } diff --git a/table/tables/tables_test.go b/table/tables/tables_test.go index edaeeb8c2d..424a2db4cd 100644 --- a/table/tables/tables_test.go +++ b/table/tables/tables_test.go @@ -82,7 +82,7 @@ func (ts *testSuite) TestBasic(c *C) { _, err = tb.AddRecord(ctx, []interface{}{2, "abc"}) c.Assert(err, NotNil) - c.Assert(tb.UpdateRecord(ctx, rid, []interface{}{1, "abc"}, []interface{}{1, "cba"}, []bool{false, true}), IsNil) + c.Assert(tb.UpdateRecord(ctx, rid, []interface{}{1, "abc"}, []interface{}{1, "cba"}, map[int]bool{0: false, 1: true}), IsNil) tb.IterRecords(ctx, tb.FirstKey(), tb.Cols(), func(h int64, data []interface{}, cols []*column.Col) (bool, error) { return true, nil