Merge pull request #505 from pingcap/qiuyesuifeng/schema-change-add-column

Add column schema change support.
This commit is contained in:
siddontang
2015-11-04 18:49:23 +08:00
18 changed files with 631 additions and 333 deletions

View File

@ -55,10 +55,14 @@ func (c *Col) String() string {
}
// FindCol finds column in cols by name.
func FindCol(cols []*Col, name string) (c *Col) {
for _, c = range cols {
if strings.EqualFold(c.Name.O, name) {
return
func FindCol(cols []*Col, name string) *Col {
for _, col := range cols {
if col.State != model.StatePublic {
continue
}
if strings.EqualFold(col.Name.O, name) {
return col
}
}
return nil
@ -82,9 +86,13 @@ func FindCols(cols []*Col, names []string) ([]*Col, error) {
// FindOnUpdateCols finds columns which have OnUpdateNow flag.
func FindOnUpdateCols(cols []*Col) []*Col {
var rcols []*Col
for _, c := range cols {
if mysql.HasOnUpdateNowFlag(c.Flag) {
rcols = append(rcols, c)
for _, col := range cols {
if col.State != model.StatePublic {
continue
}
if mysql.HasOnUpdateNowFlag(col.Flag) {
rcols = append(rcols, col)
}
}
@ -181,8 +189,12 @@ func ColDescFieldNames(full bool) []string {
// CheckOnce checks if there are duplicated column names in cols.
func CheckOnce(cols []*Col) error {
m := map[string]struct{}{}
for _, v := range cols {
name := v.Name
for _, col := range cols {
if col.State != model.StatePublic {
continue
}
name := col.Name
_, ok := m[name.L]
if ok {
return errors.Errorf("column specified twice - %s", name)

View File

@ -34,6 +34,7 @@ func (s *testColumnSuite) TestString(c *C) {
col := &Col{
model.ColumnInfo{
FieldType: *types.NewFieldType(mysql.TypeTiny),
State: model.StatePublic,
},
}
col.Flen = 2
@ -109,7 +110,8 @@ func (s *testColumnSuite) TestDesc(c *C) {
func newCol(name string) *Col {
return &Col{
model.ColumnInfo{
Name: model.NewCIStr(name),
Name: model.NewCIStr(name),
State: model.StatePublic,
},
}
}

277
ddl/column.go Normal file
View File

@ -0,0 +1,277 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package ddl
import (
"sync/atomic"
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/pingcap/tidb/column"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/meta"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/table/tables"
"github.com/pingcap/tidb/util/errors2"
)
func (d *ddl) adjustColumnOffset(columns []*model.ColumnInfo, indices []*model.IndexInfo, offset int) {
offsetChanged := make(map[int]int)
for i := offset; i < len(columns); i++ {
offsetChanged[columns[i].Offset] = i
columns[i].Offset = i
}
// Update index column offset info.
for _, idx := range indices {
for _, col := range idx.Columns {
newOffset, ok := offsetChanged[col.Offset]
if ok {
col.Offset = newOffset
}
}
}
}
func (d *ddl) addColumn(tblInfo *model.TableInfo, spec *AlterSpecification) (*model.ColumnInfo, int, error) {
// Check column name duplicate.
cols := tblInfo.Columns
position := len(cols)
// Get column position.
if spec.Position.Type == ColumnPositionFirst {
position = 0
} else if spec.Position.Type == ColumnPositionAfter {
c := findCol(cols, spec.Position.RelativeColumn)
if c == nil {
return nil, 0, errors.Errorf("No such column: %v", spec.Column.Name)
}
// Insert position is after the mentioned column.
position = c.Offset + 1
}
// TODO: set constraint
col, _, err := d.buildColumnAndConstraint(position, spec.Column)
if err != nil {
return nil, 0, errors.Trace(err)
}
colInfo := &col.ColumnInfo
colInfo.State = model.StateNone
// To support add column asynchronous, we should mark its offset as the last column.
// So that we can use origin column offset to get value from row.
colInfo.Offset = len(cols)
// Insert col into the right place of the column list.
newCols := make([]*model.ColumnInfo, 0, len(cols)+1)
newCols = append(newCols, cols[:position]...)
newCols = append(newCols, colInfo)
newCols = append(newCols, cols[position:]...)
tblInfo.Columns = newCols
return colInfo, position, nil
}
func (d *ddl) onAddColumn(t *meta.Meta, job *model.Job) error {
schemaID := job.SchemaID
tblInfo, err := d.getTableInfo(t, job)
if err != nil {
return errors.Trace(err)
}
spec := &AlterSpecification{}
offset := 0
err = job.DecodeArgs(&spec, &offset)
if err != nil {
job.State = model.JobCancelled
return errors.Trace(err)
}
columnInfo := findCol(tblInfo.Columns, spec.Column.Name)
if columnInfo != nil {
if columnInfo.State == model.StatePublic {
// we already have a column with same column name
job.State = model.JobCancelled
return errors.Errorf("ADD COLUMN: column already exist %s", spec.Column.Name)
}
} else {
columnInfo, offset, err = d.addColumn(tblInfo, spec)
if err != nil {
job.State = model.JobCancelled
return errors.Trace(err)
}
// Set offset arg to job.
if offset != 0 {
job.Args = []interface{}{spec, offset}
}
}
_, err = t.GenSchemaVersion()
if err != nil {
return errors.Trace(err)
}
switch columnInfo.State {
case model.StateNone:
// none -> delete only
job.SchemaState = model.StateDeleteOnly
columnInfo.State = model.StateDeleteOnly
err = t.UpdateTable(schemaID, tblInfo)
return errors.Trace(err)
case model.StateDeleteOnly:
// delete only -> write only
job.SchemaState = model.StateWriteOnly
columnInfo.State = model.StateWriteOnly
err = t.UpdateTable(schemaID, tblInfo)
return errors.Trace(err)
case model.StateWriteOnly:
// write only -> reorganization
job.SchemaState = model.StateReorganization
columnInfo.State = model.StateReorganization
// initialize SnapshotVer to 0 for later reorganization check.
job.SnapshotVer = 0
// initialize reorg handle to 0
job.ReorgHandle = 0
atomic.StoreInt64(&d.reorgHandle, 0)
err = t.UpdateTable(schemaID, tblInfo)
return errors.Trace(err)
case model.StateReorganization:
// reorganization -> public
// get the current version for reorganization if we don't have
if job.SnapshotVer == 0 {
var ver kv.Version
ver, err = d.store.CurrentVersion()
if err != nil {
return errors.Trace(err)
}
job.SnapshotVer = ver.Ver
}
tbl, err := d.getTable(t, schemaID, tblInfo)
if err != nil {
return errors.Trace(err)
}
err = d.runReorgJob(func() error {
return d.backfillColumn(tbl, columnInfo, job.SnapshotVer, job.ReorgHandle)
})
// backfillColumn updates ReorgHandle after one batch.
// so we update the job ReorgHandle here.
job.ReorgHandle = atomic.LoadInt64(&d.reorgHandle)
if errors2.ErrorEqual(err, errWaitReorgTimeout) {
// if timeout, we should return, check for the owner and re-wait job done.
return nil
}
if err != nil {
return errors.Trace(err)
}
// Adjust column offset.
d.adjustColumnOffset(tblInfo.Columns, tblInfo.Indices, offset)
columnInfo.State = model.StatePublic
if err = t.UpdateTable(schemaID, tblInfo); err != nil {
return errors.Trace(err)
}
// finish this job
job.SchemaState = model.StatePublic
job.State = model.JobDone
return nil
default:
return errors.Errorf("invalid column state %v", columnInfo.State)
}
}
func (d *ddl) onDropColumn(t *meta.Meta, job *model.Job) error {
// TODO: complete it.
return nil
}
func (d *ddl) backfillColumn(t table.Table, columnInfo *model.ColumnInfo, version uint64, seekHandle int64) error {
for {
handles, err := d.getSnapshotRows(t, version, seekHandle)
if err != nil {
return errors.Trace(err)
} else if len(handles) == 0 {
return nil
}
seekHandle = handles[len(handles)-1] + 1
// TODO: save seekHandle in reorganization job, so we can resume this job later from this handle.
err = d.backfillColumnData(t, columnInfo, handles)
if err != nil {
return errors.Trace(err)
}
// update reorgHandle here after every successful batch.
atomic.StoreInt64(&d.reorgHandle, seekHandle)
}
}
func (d *ddl) backfillColumnData(t table.Table, columnInfo *model.ColumnInfo, handles []int64) error {
for _, handle := range handles {
log.Info("backfill column...", handle)
err := kv.RunInNewTxn(d.store, true, func(txn kv.Transaction) error {
// First check if row exists.
exist, err := checkRowExist(txn, t, handle)
if err != nil {
return errors.Trace(err)
} else if !exist {
// If row doesn't exist, skip it.
return nil
}
backfillKey := t.RecordKey(handle, &column.Col{ColumnInfo: *columnInfo})
_, err = txn.Get(backfillKey)
if err != nil && !kv.IsErrNotFound(err) {
return errors.Trace(err)
}
// If row column doesn't exist, we need to backfill column.
// Lock row first.
err = txn.LockKeys(t.RecordKey(handle, nil))
if err != nil {
return errors.Trace(err)
}
value, _, err := tables.GetColDefaultValue(nil, columnInfo)
if err != nil {
return errors.Trace(err)
}
err = t.SetColValue(txn, backfillKey, value)
if err != nil {
return errors.Trace(err)
}
return nil
})
if err != nil {
return errors.Trace(err)
}
}
return nil
}

View File

@ -34,8 +34,6 @@ import (
"github.com/pingcap/tidb/parser/coldef"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/table/tables"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/charset"
qerror "github.com/pingcap/tidb/util/errors"
"github.com/twinj/uuid"
@ -182,20 +180,6 @@ func (d *ddl) CreateSchema(ctx context.Context, schema model.CIStr) (err error)
return errors.Trace(err)
}
func (d *ddl) verifySchemaMetaVersion(t *meta.Meta, schemaMetaVersion int64) error {
curVer, err := t.GetSchemaVersion()
if err != nil {
return errors.Trace(err)
}
if curVer != schemaMetaVersion {
return errors.Errorf("Schema changed, our version %d, but got %d", schemaMetaVersion, curVer)
}
// Increment version.
_, err = t.GenSchemaVersion()
return errors.Trace(err)
}
func (d *ddl) DropSchema(ctx context.Context, schema model.CIStr) (err error) {
is := d.GetInformationSchema()
old, ok := is.SchemaByName(schema)
@ -288,7 +272,7 @@ func (d *ddl) buildColumnsAndConstraints(colDefs []*coldef.ColumnDef, constraint
}
func (d *ddl) buildColumnAndConstraint(offset int, colDef *coldef.ColumnDef) (*column.Col, []*coldef.TableConstraint, error) {
// set charset
// Set charset.
if len(colDef.Tp.Charset) == 0 {
switch colDef.Tp.Tp {
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
@ -298,15 +282,17 @@ func (d *ddl) buildColumnAndConstraint(offset int, colDef *coldef.ColumnDef) (*c
colDef.Tp.Collate = charset.CharsetBin
}
}
// convert colDef into col
col, cts, err := coldef.ColumnDefToCol(offset, colDef)
if err != nil {
return nil, nil, errors.Trace(err)
}
col.ID, err = d.genGlobalID()
if err != nil {
return nil, nil, errors.Trace(err)
}
return col, cts, nil
}
@ -439,19 +425,6 @@ func (d *ddl) CreateTable(ctx context.Context, ident table.Ident, colDefs []*col
}
func (d *ddl) AlterTable(ctx context.Context, ident table.Ident, specs []*AlterSpecification) (err error) {
// Get database and table.
is := d.GetInformationSchema()
schema, ok := is.SchemaByName(ident.Schema)
if !ok {
return errors.Trace(qerror.ErrDatabaseNotExist)
}
tbl, err := is.TableByName(ident.Schema, ident.Name)
if err != nil {
return errors.Trace(ErrNotExists)
}
// now we only allow one schema changes at the same time.
if len(specs) != 1 {
return errors.New("can't run multi schema changes in one DDL")
@ -460,7 +433,7 @@ func (d *ddl) AlterTable(ctx context.Context, ident table.Ident, specs []*AlterS
for _, spec := range specs {
switch spec.Action {
case AlterAddColumn:
err = d.addColumn(ctx, schema, tbl, spec, is.SchemaMetaVersion())
err = d.AddColumn(ctx, ident, spec)
case AlterDropIndex:
err = d.DropIndex(ctx, ident, model.NewCIStr(spec.Name))
case AlterAddConstr:
@ -481,116 +454,52 @@ func (d *ddl) AlterTable(ctx context.Context, ident table.Ident, specs []*AlterS
return errors.Trace(err)
}
}
return nil
}
// Add a column into table
func (d *ddl) addColumn(ctx context.Context, schema *model.DBInfo, tbl table.Table, spec *AlterSpecification, schemaMetaVersion int64) error {
// Find position
cols := tbl.Cols()
position := len(cols)
name := spec.Column.Name
// Check column name duplicate.
dc := column.FindCol(cols, name)
if dc != nil {
return errors.Errorf("Try to add a column with the same name of an already exists column.")
}
if spec.Position.Type == ColumnPositionFirst {
position = 0
} else if spec.Position.Type == ColumnPositionAfter {
// Find the mentioned column.
c := column.FindCol(cols, spec.Position.RelativeColumn)
if c == nil {
return errors.Errorf("No such column: %v", name)
func checkColumnConstraint(constraints []*coldef.ConstraintOpt) error {
for _, constraint := range constraints {
switch constraint.Tp {
case coldef.ConstrAutoIncrement, coldef.ConstrForeignKey, coldef.ConstrPrimaryKey, coldef.ConstrUniq, coldef.ConstrUniqKey:
return errors.Errorf("unsupported add column constraint - %s", constraint)
}
// Insert position is after the mentioned column.
position = c.Offset + 1
}
// TODO: set constraint
col, _, err := d.buildColumnAndConstraint(position, spec.Column)
return nil
}
// AddColumn will add a new column to the table.
func (d *ddl) AddColumn(ctx context.Context, ti table.Ident, spec *AlterSpecification) error {
// Check whether the added column constraints are supported.
err := checkColumnConstraint(spec.Column.Constraints)
if err != nil {
return errors.Trace(err)
}
// insert col into the right place of the column list
newCols := make([]*column.Col, 0, len(cols)+1)
newCols = append(newCols, cols[:position]...)
newCols = append(newCols, col)
newCols = append(newCols, cols[position:]...)
// adjust position
if position != len(cols) {
offsetChange := make(map[int]int)
for i := position + 1; i < len(newCols); i++ {
offsetChange[newCols[i].Offset] = i
newCols[i].Offset = i
}
// Update index offset info
for _, idx := range tbl.Indices() {
for _, c := range idx.Columns {
newOffset, ok := offsetChange[c.Offset]
if ok {
c.Offset = newOffset
}
}
}
}
tb := tbl.(*tables.Table)
tb.Columns = newCols
// TODO: update index
if err = updateOldRows(ctx, tb, col); err != nil {
return errors.Trace(err)
is := d.infoHandle.Get()
schema, ok := is.SchemaByName(ti.Schema)
if !ok {
return errors.Trace(qerror.ErrDatabaseNotExist)
}
// update infomation schema
err = kv.RunInNewTxn(d.store, false, func(txn kv.Transaction) error {
t := meta.NewMeta(txn)
err := d.verifySchemaMetaVersion(t, schemaMetaVersion)
if err != nil {
return errors.Trace(err)
}
t, err := is.TableByName(ti.Schema, ti.Name)
if err != nil {
return errors.Trace(ErrNotExists)
}
err = t.UpdateTable(schema.ID, tb.Meta())
return errors.Trace(err)
})
job := &model.Job{
SchemaID: schema.ID,
TableID: t.Meta().ID,
Type: model.ActionAddColumn,
Args: []interface{}{spec, 0},
}
err = d.startJob(ctx, job)
err = d.onDDLChange(err)
return errors.Trace(err)
}
func updateOldRows(ctx context.Context, t *tables.Table, col *column.Col) error {
txn, err := ctx.GetTxn(false)
if err != nil {
return errors.Trace(err)
}
it, err := txn.Seek([]byte(t.FirstKey()))
if err != nil {
return errors.Trace(err)
}
defer it.Close()
prefix := t.KeyPrefix()
for it.Valid() && strings.HasPrefix(it.Key(), prefix) {
handle, err0 := util.DecodeHandleFromRowKey(it.Key())
if err0 != nil {
return errors.Trace(err0)
}
k := t.RecordKey(handle, col)
// TODO: check and get timestamp/datetime default value.
// refer to getDefaultValue in stmt/stmts/stmt_helper.go.
if err0 = t.SetColValue(txn, k, col.DefaultValue); err0 != nil {
return errors.Trace(err0)
}
rk := t.RecordKey(handle, nil)
if err0 = kv.NextUntil(it, util.RowKeyPrefixFilter(rk)); err0 != nil {
return errors.Trace(err0)
}
}
return nil
}
// DropTable will proceed even if some table in the list does not exists.
func (d *ddl) DropTable(ctx context.Context, ti table.Ident) (err error) {
is := d.GetInformationSchema()
@ -693,3 +602,15 @@ func (d *ddl) DropIndex(ctx context.Context, ti table.Ident, indexName model.CIS
err = d.onDDLChange(err)
return errors.Trace(err)
}
// findCol finds column in cols by name.
func findCol(cols []*model.ColumnInfo, name string) *model.ColumnInfo {
name = strings.ToLower(name)
for _, col := range cols {
if col.Name.L == name {
return col
}
}
return nil
}

View File

@ -49,7 +49,7 @@ func (ts *testSuite) SetUpSuite(c *C) {
ts.store = store
}
func (ts *testSuite) TestT(c *C) {
func (ts *testSuite) TestDDL(c *C) {
se, _ := tidb.CreateSession(ts.store)
ctx := se.(context.Context)
schemaName := model.NewCIStr("test_ddl")
@ -89,6 +89,9 @@ func (ts *testSuite) TestT(c *C) {
alterStmt := statement(`alter table t2 add b enum("bb") first`).(*stmts.AlterTableStmt)
sessionctx.GetDomain(ctx).DDL().AlterTable(ctx, tbIdent2, alterStmt.Specs)
c.Assert(alterStmt.Specs[0].String(), Not(Equals), "")
tb, err = sessionctx.GetDomain(ctx).InfoSchema().TableByName(tbIdent2.Schema, tbIdent2.Name)
c.Assert(err, IsNil)
c.Assert(tb, NotNil)
cols, err := tb.Row(ctx, rid0)
c.Assert(err, IsNil)
c.Assert(len(cols), Equals, 2)

View File

@ -15,7 +15,6 @@ package ddl
import (
"bytes"
"strings"
"sync/atomic"
"github.com/juju/errors"
@ -30,39 +29,6 @@ import (
"github.com/pingcap/tidb/util/errors2"
)
func (d *ddl) getTableInfo(t *meta.Meta, job *model.Job) (*model.TableInfo, error) {
schemaID := job.SchemaID
tableID := job.TableID
tblInfo, err := t.GetTable(schemaID, tableID)
if errors2.ErrorEqual(err, meta.ErrDBNotExists) {
job.State = model.JobCancelled
return nil, errors.Trace(ErrNotExists)
} else if err != nil {
return nil, errors.Trace(err)
} else if tblInfo == nil {
job.State = model.JobCancelled
return nil, errors.Trace(ErrNotExists)
}
if tblInfo.State != model.StatePublic {
job.State = model.JobCancelled
return nil, errors.Errorf("table %s is not in public, but %s", tblInfo.Name.L, tblInfo.State)
}
return tblInfo, nil
}
// FindCol finds column in cols by name.
func findCol(cols []*model.ColumnInfo, name string) (c *model.ColumnInfo) {
name = strings.ToLower(name)
for _, c = range cols {
if c.Name.L == name {
return
}
}
return nil
}
func buildIndexInfo(tblInfo *model.TableInfo, unique bool, indexName model.CIStr, idxColNames []*coldef.IndexColName) (*model.IndexInfo, error) {
for _, col := range tblInfo.Columns {
if col.Name.L == indexName.L {
@ -129,7 +95,7 @@ func dropIndexColumnFlag(tblInfo *model.TableInfo, indexInfo *model.IndexInfo) {
}
}
func (d *ddl) onIndexCreate(t *meta.Meta, job *model.Job) error {
func (d *ddl) onCreateIndex(t *meta.Meta, job *model.Job) error {
schemaID := job.SchemaID
tblInfo, err := d.getTableInfo(t, job)
if err != nil {
@ -189,7 +155,7 @@ func (d *ddl) onIndexCreate(t *meta.Meta, job *model.Job) error {
err = t.UpdateTable(schemaID, tblInfo)
return errors.Trace(err)
case model.StateWriteOnly:
// write only -> public
// write only -> reorganization
job.SchemaState = model.StateReorganization
indexInfo.State = model.StateReorganization
// initialize SnapshotVer to 0 for later reorganization check.
@ -250,7 +216,7 @@ func (d *ddl) onIndexCreate(t *meta.Meta, job *model.Job) error {
}
}
func (d *ddl) onIndexDrop(t *meta.Meta, job *model.Job) error {
func (d *ddl) onDropIndex(t *meta.Meta, job *model.Job) error {
schemaID := job.SchemaID
tblInfo, err := d.getTableInfo(t, job)
if err != nil {

View File

@ -22,7 +22,7 @@ import (
"github.com/pingcap/tidb/util/errors2"
)
func (d *ddl) onSchemaCreate(t *meta.Meta, job *model.Job) error {
func (d *ddl) onCreateSchema(t *meta.Meta, job *model.Job) error {
schemaID := job.SchemaID
var name model.CIStr
if err := job.DecodeArgs(&name); err != nil {
@ -90,7 +90,7 @@ func (d *ddl) onSchemaCreate(t *meta.Meta, job *model.Job) error {
}
}
func (d *ddl) onSchemaDrop(t *meta.Meta, job *model.Job) error {
func (d *ddl) onDropSchema(t *meta.Meta, job *model.Job) error {
dbInfo, err := t.GetDatabase(job.SchemaID)
if err != nil {
return errors.Trace(err)

View File

@ -23,7 +23,7 @@ import (
"github.com/pingcap/tidb/util/errors2"
)
func (d *ddl) onTableCreate(t *meta.Meta, job *model.Job) error {
func (d *ddl) onCreateTable(t *meta.Meta, job *model.Job) error {
schemaID := job.SchemaID
tbInfo := &model.TableInfo{}
if err := job.DecodeArgs(tbInfo); err != nil {
@ -89,7 +89,7 @@ func (d *ddl) onTableCreate(t *meta.Meta, job *model.Job) error {
}
}
func (d *ddl) onTableDrop(t *meta.Meta, job *model.Job) error {
func (d *ddl) onDropTable(t *meta.Meta, job *model.Job) error {
schemaID := job.SchemaID
tableID := job.TableID
@ -170,6 +170,28 @@ func (d *ddl) getTable(t *meta.Meta, schemaID int64, tblInfo *model.TableInfo) (
return tbl, nil
}
func (d *ddl) getTableInfo(t *meta.Meta, job *model.Job) (*model.TableInfo, error) {
schemaID := job.SchemaID
tableID := job.TableID
tblInfo, err := t.GetTable(schemaID, tableID)
if errors2.ErrorEqual(err, meta.ErrDBNotExists) {
job.State = model.JobCancelled
return nil, errors.Trace(ErrNotExists)
} else if err != nil {
return nil, errors.Trace(err)
} else if tblInfo == nil {
job.State = model.JobCancelled
return nil, errors.Trace(ErrNotExists)
}
if tblInfo.State != model.StatePublic {
job.State = model.JobCancelled
return nil, errors.Errorf("table %s is not in public, but %s", tblInfo.Name.L, tblInfo.State)
}
return tblInfo, nil
}
func (d *ddl) dropTableData(t table.Table) error {
// delete table data
err := d.delKeysWithPrefix(t.KeyPrefix())

View File

@ -280,19 +280,21 @@ func (d *ddl) runJob(t *meta.Meta, job *model.Job) error {
var err error
switch job.Type {
case model.ActionCreateSchema:
err = d.onSchemaCreate(t, job)
err = d.onCreateSchema(t, job)
case model.ActionDropSchema:
err = d.onSchemaDrop(t, job)
err = d.onDropSchema(t, job)
case model.ActionCreateTable:
err = d.onTableCreate(t, job)
err = d.onCreateTable(t, job)
case model.ActionDropTable:
err = d.onTableDrop(t, job)
err = d.onDropTable(t, job)
case model.ActionAddColumn:
err = d.onAddColumn(t, job)
case model.ActionDropColumn:
err = d.onDropColumn(t, job)
case model.ActionAddIndex:
err = d.onIndexCreate(t, job)
err = d.onCreateIndex(t, job)
case model.ActionDropIndex:
err = d.onIndexDrop(t, job)
err = d.onDropIndex(t, job)
case model.ActionAddConstraint:
log.Fatal("Doesn't support change constraint online")
case model.ActionDropConstraint:

View File

@ -125,13 +125,13 @@ func (m *Meta) parseTableID(key string) (int64, error) {
// GenAutoTableID adds step to the auto id of the table and returns the sum.
func (m *Meta) GenAutoTableID(dbID int64, tableID int64, step int64) (int64, error) {
// check db exists
// Check if db exists.
dbKey := m.dbKey(dbID)
if err := m.checkDBExists(dbKey); err != nil {
return 0, errors.Trace(err)
}
// check table exists
// Check if table exists.
tableKey := m.tableKey(tableID)
if err := m.checkTableExists(dbKey, tableKey); err != nil {
return 0, errors.Trace(err)
@ -239,14 +239,14 @@ func (m *Meta) UpdateDatabase(dbInfo *model.DBInfo) error {
// CreateTable creates a table with tableInfo in database.
func (m *Meta) CreateTable(dbID int64, tableInfo *model.TableInfo) error {
// first check db exists or not.
// Check if db exists.
dbKey := m.dbKey(dbID)
if err := m.checkDBExists(dbKey); err != nil {
return errors.Trace(err)
}
// Check if table exists.
tableKey := m.tableKey(tableInfo.ID)
// then check table exists or not
if err := m.checkTableNotExists(dbKey, tableKey); err != nil {
return errors.Trace(err)
}
@ -261,7 +261,7 @@ func (m *Meta) CreateTable(dbID int64, tableInfo *model.TableInfo) error {
// DropDatabase drops whole database.
func (m *Meta) DropDatabase(dbID int64) error {
// check if db exists.
// Check if db exists.
dbKey := m.dbKey(dbID)
if err := m.txn.HClear(dbKey); err != nil {
return errors.Trace(err)
@ -276,14 +276,14 @@ func (m *Meta) DropDatabase(dbID int64) error {
// DropTable drops table in database.
func (m *Meta) DropTable(dbID int64, tableID int64) error {
// first check db exists or not.
// Check if db exists.
dbKey := m.dbKey(dbID)
if err := m.checkDBExists(dbKey); err != nil {
return errors.Trace(err)
}
// Check if table exists.
tableKey := m.tableKey(tableID)
// then check table exists or not
if err := m.checkTableExists(dbKey, tableKey); err != nil {
return errors.Trace(err)
}
@ -301,15 +301,14 @@ func (m *Meta) DropTable(dbID int64, tableID int64) error {
// UpdateTable updates the table with table info.
func (m *Meta) UpdateTable(dbID int64, tableInfo *model.TableInfo) error {
// first check db exists or not.
// Check if db exists.
dbKey := m.dbKey(dbID)
if err := m.checkDBExists(dbKey); err != nil {
return errors.Trace(err)
}
// Check if table exists.
tableKey := m.tableKey(tableInfo.ID)
// then check table exists or not
if err := m.checkTableExists(dbKey, tableKey); err != nil {
return errors.Trace(err)
}
@ -320,7 +319,6 @@ func (m *Meta) UpdateTable(dbID int64, tableInfo *model.TableInfo) error {
}
err = m.txn.HSet(dbKey, tableKey, data)
return errors.Trace(err)
}
@ -390,7 +388,7 @@ func (m *Meta) GetDatabase(dbID int64) (*model.DBInfo, error) {
// GetTable gets the table value in database with tableID.
func (m *Meta) GetTable(dbID int64, tableID int64) (*model.TableInfo, error) {
// first check db exists or not.
// Check if db exists.
dbKey := m.dbKey(dbID)
if err := m.checkDBExists(dbKey); err != nil {
return nil, errors.Trace(err)

View File

@ -83,6 +83,7 @@ func (p *testFromSuit) SetUpSuite(c *C) {
Offset: 0,
DefaultValue: 0,
FieldType: *types.NewFieldType(mysql.TypeLonglong),
State: model.StatePublic,
},
},
{
@ -92,6 +93,7 @@ func (p *testFromSuit) SetUpSuite(c *C) {
Offset: 1,
DefaultValue: nil,
FieldType: *types.NewFieldType(mysql.TypeVarchar),
State: model.StatePublic,
},
},
}
@ -100,7 +102,8 @@ func (p *testFromSuit) SetUpSuite(c *C) {
var i int64
for i = 0; i < 10; i++ {
p.tbl.AddRecord(p, []interface{}{i * 10, "hello"})
_, err = p.tbl.AddRecord(p, []interface{}{i * 10, "hello"})
c.Assert(err, IsNil)
}
}

View File

@ -30,6 +30,7 @@ import (
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/stmt"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/table/tables"
"github.com/pingcap/tidb/util/errors2"
"github.com/pingcap/tidb/util/format"
"github.com/pingcap/tidb/util/types"
@ -92,6 +93,7 @@ func (s *InsertValues) execSelect(t table.Table, cols []*column.Col, ctx context
return nil, errors.Trace(err)
}
defer r.Close()
if len(r.GetFields()) != len(cols) {
return nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(r.GetFields()))
}
@ -107,26 +109,29 @@ func (s *InsertValues) execSelect(t table.Table, cols []*column.Col, ctx context
if row == nil {
break
}
data0 := make([]interface{}, len(t.Cols()))
marked := make(map[int]struct{}, len(cols))
for i, d := range row.Data {
data0[cols[i].Offset] = d
marked[cols[i].Offset] = struct{}{}
currentRow := make([]interface{}, len(t.Cols()))
marked := make(map[int]struct{}, len(t.Cols()))
for i, data := range row.Data {
offset := cols[i].Offset
currentRow[offset] = data
marked[offset] = struct{}{}
}
if err = s.initDefaultValues(ctx, t, data0, marked); err != nil {
if err = s.initDefaultValues(ctx, t, currentRow, marked); err != nil {
return nil, errors.Trace(err)
}
if err = column.CastValues(ctx, data0, cols); err != nil {
if err = column.CastValues(ctx, currentRow, cols); err != nil {
return nil, errors.Trace(err)
}
if err = column.CheckNotNull(t.Cols(), data0); err != nil {
if err = column.CheckNotNull(t.Cols(), currentRow); err != nil {
return nil, errors.Trace(err)
}
var v interface{}
v, err = types.Clone(data0)
v, err = types.Clone(currentRow)
if err != nil {
return nil, errors.Trace(err)
}
@ -190,19 +195,19 @@ func (s *InsertValues) getColumns(tableCols []*column.Col) ([]*column.Col, error
return cols, nil
}
func (s *InsertValues) getDefaultValues(ctx context.Context, cols []*column.Col) (map[interface{}]interface{}, error) {
m := map[interface{}]interface{}{}
for _, v := range cols {
if value, ok, err := getDefaultValue(ctx, v); ok {
func (s *InsertValues) getColumnDefaultValues(ctx context.Context, cols []*column.Col) (map[interface{}]interface{}, error) {
defaultValMap := map[interface{}]interface{}{}
for _, col := range cols {
if value, ok, err := tables.GetColDefaultValue(ctx, &col.ColumnInfo); ok {
if err != nil {
return nil, errors.Trace(err)
}
m[v.Name.L] = value
defaultValMap[col.Name.L] = value
}
}
return m, nil
return defaultValMap, nil
}
func (s *InsertValues) fillValueList() error {
@ -227,6 +232,7 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error)
if err != nil {
return nil, errors.Trace(err)
}
cols, err := s.getColumns(t.Cols())
if err != nil {
return nil, errors.Trace(err)
@ -243,10 +249,11 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error)
return nil, errors.Trace(err)
}
m, err := s.getDefaultValues(ctx, t.Cols())
defaultValMap, err := s.getColumnDefaultValues(ctx, t.Cols())
if err != nil {
return nil, errors.Trace(err)
}
insertValueCount := len(s.Lists[0])
toUpdateColumns, err := getOnDuplicateUpdateColumns(s.OnDuplicate, t)
if err != nil {
@ -258,7 +265,7 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error)
return nil, errors.Trace(err)
}
row, err := s.getRow(ctx, t, cols, list, m)
row, err := s.fillRowData(ctx, t, cols, list, defaultValMap)
if err != nil {
return nil, errors.Trace(err)
}
@ -276,6 +283,7 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error)
if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrKeyExists) {
return nil, errors.Trace(err)
}
if err = execOnDuplicateUpdate(ctx, t, row, h, toUpdateColumns); err != nil {
return nil, errors.Trace(err)
}
@ -303,36 +311,40 @@ func (s *InsertValues) checkValueCount(insertValueCount, valueCount, num int, co
return nil
}
func (s *InsertValues) getRow(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, error) {
r := make([]interface{}, len(t.Cols()))
func (s *InsertValues) fillRowData(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, evalMap map[interface{}]interface{}) ([]interface{}, error) {
row := make([]interface{}, len(t.Cols()))
marked := make(map[int]struct{}, len(list))
for i, expr := range list {
// For "insert into t values (default)" Default Eval.
m[expression.ExprEvalDefaultName] = cols[i].Name.O
evalMap[expression.ExprEvalDefaultName] = cols[i].Name.O
val, err := expr.Eval(ctx, m)
val, err := expr.Eval(ctx, evalMap)
if err != nil {
return nil, errors.Trace(err)
}
r[cols[i].Offset] = val
marked[cols[i].Offset] = struct{}{}
offset := cols[i].Offset
row[offset] = val
marked[offset] = struct{}{}
}
// Clear last insert id.
variable.GetSessionVars(ctx).SetLastInsertID(0)
err := s.initDefaultValues(ctx, t, r, marked)
err := s.initDefaultValues(ctx, t, row, marked)
if err != nil {
return nil, errors.Trace(err)
}
if err = column.CastValues(ctx, r, cols); err != nil {
return nil, errors.Trace(err)
}
if err = column.CheckNotNull(t.Cols(), r); err != nil {
if err = column.CastValues(ctx, row, cols); err != nil {
return nil, errors.Trace(err)
}
return r, nil
if err = column.CheckNotNull(t.Cols(), row); err != nil {
return nil, errors.Trace(err)
}
return row, nil
}
func execOnDuplicateUpdate(ctx context.Context, t table.Table, row []interface{}, h int64, cols map[int]*expression.Assignment) error {
@ -396,7 +408,7 @@ func (s *InsertValues) initDefaultValues(ctx context.Context, t table.Table, row
variable.GetSessionVars(ctx).SetLastInsertID(uint64(id))
} else {
var value interface{}
value, _, err = getDefaultValue(ctx, c)
value, _, err = tables.GetColDefaultValue(ctx, &c.ColumnInfo)
if err != nil {
return errors.Trace(err)
}

View File

@ -61,6 +61,7 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error
if err != nil {
return nil, errors.Trace(err)
}
cols, err := s.getColumns(t.Cols())
if err != nil {
return nil, errors.Trace(err)
@ -71,28 +72,33 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error
if s.Sel != nil {
return s.execSelect(t, cols, ctx)
}
// Process `replace ... set x=y ...`
if err = s.fillValueList(); err != nil {
return nil, errors.Trace(err)
}
m, err := s.getDefaultValues(ctx, t.Cols())
evalMap, err := s.getColumnDefaultValues(ctx, t.Cols())
if err != nil {
return nil, errors.Trace(err)
}
replaceValueCount := len(s.Lists[0])
replaceValueCount := len(s.Lists[0])
for i, list := range s.Lists {
if err = s.checkValueCount(replaceValueCount, len(list), i, cols); err != nil {
return nil, errors.Trace(err)
}
row, err := s.getRow(ctx, t, cols, list, m)
row, err := s.fillRowData(ctx, t, cols, list, evalMap)
if err != nil {
return nil, errors.Trace(err)
}
h, err := t.AddRecord(ctx, row)
if err == nil {
continue
}
if err != nil && !errors2.ErrorEqual(err, kv.ErrKeyExists) {
return nil, errors.Trace(err)
}
@ -115,18 +121,20 @@ func replaceRow(ctx context.Context, t table.Table, handle int64, replaceRow []i
return errors.Trace(err)
}
result := 0
isReplace := false
touched := make([]bool, len(row))
touched := make(map[int]bool, len(row))
for i, val := range row {
v, err1 := types.Compare(val, replaceRow[i])
if err1 != nil {
return errors.Trace(err1)
result, err = types.Compare(val, replaceRow[i])
if err != nil {
return errors.Trace(err)
}
if v != 0 {
if result != 0 {
touched[i] = true
isReplace = true
}
}
if isReplace {
variable.GetSessionVars(ctx).AddAffectedRows(1)
if err = t.UpdateRecord(ctx, handle, row, replaceRow, touched); err != nil {

View File

@ -14,44 +14,11 @@
package stmts
import (
"github.com/juju/errors"
"github.com/pingcap/tidb/column"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/table"
)
func getDefaultValue(ctx context.Context, c *column.Col) (interface{}, bool, error) {
// Check no default value flag.
if mysql.HasNoDefaultValueFlag(c.Flag) && c.Tp != mysql.TypeEnum {
return nil, false, errors.Errorf("Field '%s' doesn't have a default value", c.Name)
}
// Check and get timestamp/datetime default value.
if c.Tp == mysql.TypeTimestamp || c.Tp == mysql.TypeDatetime {
if c.DefaultValue == nil {
return nil, true, nil
}
value, err := expression.GetTimeValue(ctx, c.DefaultValue, c.Tp, c.Decimal)
if err != nil {
return nil, true, errors.Errorf("Field '%s' get default value fail - %s", c.Name, errors.Trace(err))
}
return value, true, nil
} else if c.Tp == mysql.TypeEnum {
// For enum type, if no default value and not null is set,
// the default value is the first element of the enum list
if c.DefaultValue == nil && mysql.HasNotNullFlag(c.Flag) {
return c.FieldType.Elems[0], true, nil
}
}
return c.DefaultValue, true, nil
}
func getTable(ctx context.Context, tableIdent table.Ident) (table.Table, error) {
full := tableIdent.Full(ctx)
return sessionctx.GetDomain(ctx).InfoSchema().TableByName(full.Schema, full.Name)

View File

@ -119,17 +119,17 @@ func getUpdateColumns(assignList []*expression.Assignment, fields []*field.Resul
}
func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Table,
updateColumns map[int]*expression.Assignment, m map[interface{}]interface{},
updateColumns map[int]*expression.Assignment, evalMap map[interface{}]interface{},
offset int, onDuplicateUpdate bool) error {
if err := t.LockRow(ctx, h, true); err != nil {
return errors.Trace(err)
}
oldData := make([]interface{}, len(t.Cols()))
touched := make([]bool, len(t.Cols()))
copy(oldData, data)
cols := t.Cols()
oldData := data
newData := make([]interface{}, len(cols))
touched := make(map[int]bool, len(cols))
copy(newData, oldData)
assignExists := false
for i, asgn := range updateColumns {
@ -137,47 +137,49 @@ func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Tabl
// The assign expression is for another table, not this.
continue
}
val, err := asgn.Expr.Eval(ctx, m)
val, err := asgn.Expr.Eval(ctx, evalMap)
if err != nil {
return err
return errors.Trace(err)
}
colIndex := i - offset
touched[colIndex] = true
data[colIndex] = val
newData[colIndex] = val
assignExists = true
}
// no assign list for this table, no need to update.
// If no assign list for this table, no need to update.
if !assignExists {
return nil
}
// Check whether new value is valid.
if err := column.CastValues(ctx, data, t.Cols()); err != nil {
return err
if err := column.CastValues(ctx, newData, cols); err != nil {
return errors.Trace(err)
}
if err := column.CheckNotNull(t.Cols(), data); err != nil {
return err
if err := column.CheckNotNull(cols, newData); err != nil {
return errors.Trace(err)
}
// If row is not changed, we should do nothing.
rowChanged := false
for i, d := range data {
for i := range oldData {
if !touched[i] {
continue
}
od := oldData[i]
n, err := types.Compare(d, od)
n, err := types.Compare(newData[i], oldData[i])
if err != nil {
return errors.Trace(err)
}
if n != 0 {
rowChanged = true
break
}
}
if !rowChanged {
// See: https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html CLIENT_FOUND_ROWS
if variable.GetSessionVars(ctx).ClientCapability&mysql.ClientFoundRows > 0 {
@ -187,10 +189,11 @@ func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Tabl
}
// Update record to new value and update index.
err := t.UpdateRecord(ctx, h, oldData, data, touched)
err := t.UpdateRecord(ctx, h, oldData, newData, touched)
if err != nil {
return errors.Trace(err)
}
// Record affected rows.
if !onDuplicateUpdate {
variable.GetSessionVars(ctx).AddAffectedRows(1)
@ -198,6 +201,7 @@ func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Tabl
variable.GetSessionVars(ctx).AddAffectedRows(2)
}
return nil
}
@ -253,17 +257,13 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
return nil, errors.Trace(err)
}
defer p.Close()
updatedRowKeys := make(map[string]bool)
// Get table alias map.
fs := p.GetFields()
columns, err0 := getUpdateColumns(s.List, fs)
if err0 != nil {
return nil, errors.Trace(err0)
columns, err := getUpdateColumns(s.List, fs)
if err != nil {
return nil, errors.Trace(err)
}
m := map[interface{}]interface{}{}
var records []*plan.Row
for {
row, err1 := p.Next(ctx)
@ -280,15 +280,17 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
records = append(records, row)
}
evalMap := map[interface{}]interface{}{}
updatedRowKeys := make(map[string]bool)
for _, row := range records {
rowData := row.Data
// Set ExprEvalIdentReferFunc
m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
// Set ExprEvalIdentReferFunc.
evalMap[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) {
return rowData[index], nil
}
// Update rows
// Update rows.
offset := 0
for _, entry := range row.RowKeys {
tbl := entry.Tbl
@ -302,13 +304,14 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
// Each matching row is updated once, even if it matches the conditions multiple times.
continue
}
// Update row
handle, err2 := util.DecodeHandleFromRowKey(k)
if err2 != nil {
return nil, errors.Trace(err2)
}
err2 = updateRecord(ctx, handle, data, tbl, columns, m, lastOffset, false)
err2 = updateRecord(ctx, handle, data, tbl, columns, evalMap, lastOffset, false)
if err2 != nil {
return nil, errors.Trace(err2)
}

View File

@ -22,6 +22,7 @@ import (
"github.com/pingcap/tidb/column"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/meta/autoid"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/sessionctx/db"
@ -56,7 +57,7 @@ type Table interface {
// TableName returns table name.
TableName() model.CIStr
// Cols returns the columns of the table.
// Cols returns the columns of the table which is used in select.
Cols() []*column.Col
// Indices returns the indices of the table.
@ -87,7 +88,7 @@ type Table interface {
AddRecord(ctx context.Context, r []interface{}) (recordID int64, err error)
// UpdateRecord updates a row in the table.
UpdateRecord(ctx context.Context, h int64, currData []interface{}, newData []interface{}, touched []bool) error
UpdateRecord(ctx context.Context, h int64, currData []interface{}, newData []interface{}, touched map[int]bool) error
// TableID returns the ID of the table.
TableID() int64
@ -107,6 +108,10 @@ type Table interface {
// LockRow locks a row.
// If update is true, set row lock key to current txn.
LockRow(ctx context.Context, h int64, update bool) error
// SetColValue sets the column value.
// If the column is untouched, we don't need to do this.
SetColValue(txn kv.Transaction, key []byte, data interface{}) error
}
// TableFromMeta builds a table.Table from *model.TableInfo.

View File

@ -44,11 +44,13 @@ type Table struct {
Name model.CIStr
Columns []*column.Col
indices []*column.IndexedCol
recordPrefix string
indexPrefix string
alloc autoid.Allocator
state model.SchemaState
publicColumns []*column.Col
writableColumns []*column.Col
indices []*column.IndexedCol
recordPrefix string
indexPrefix string
alloc autoid.Allocator
state model.SchemaState
}
// TableFromMeta creates a Table instance from model.TableInfo.
@ -56,8 +58,8 @@ func TableFromMeta(alloc autoid.Allocator, tblInfo *model.TableInfo) table.Table
t := NewTable(tblInfo.ID, tblInfo.Name.O, nil, alloc)
for _, colInfo := range tblInfo.Columns {
c := column.Col{ColumnInfo: *colInfo}
t.Columns = append(t.Columns, &c)
col := &column.Col{ColumnInfo: *colInfo}
t.Columns = append(t.Columns, col)
}
for _, idxInfo := range tblInfo.Indices {
@ -84,6 +86,9 @@ func NewTable(tableID int64, tableName string, cols []*column.Col, alloc autoid.
Columns: cols,
state: model.StatePublic,
}
t.publicColumns = t.Cols()
t.writableColumns = t.writableCols()
return t
}
@ -129,7 +134,35 @@ func (t *Table) Meta() *model.TableInfo {
// Cols implements table.Table Cols interface.
func (t *Table) Cols() []*column.Col {
return t.Columns
if len(t.publicColumns) > 0 {
return t.publicColumns
}
t.publicColumns = make([]*column.Col, 0, len(t.Columns))
for _, col := range t.Columns {
if col.State == model.StatePublic {
t.publicColumns = append(t.publicColumns, col)
}
}
return t.publicColumns
}
func (t *Table) writableCols() []*column.Col {
if len(t.writableColumns) > 0 {
return t.writableColumns
}
t.writableColumns = make([]*column.Col, 0, len(t.Columns))
for _, col := range t.Columns {
if col.State == model.StateDeleteOnly {
continue
}
t.writableColumns = append(t.writableColumns, col)
}
return t.writableColumns
}
func (t *Table) unflatten(rec interface{}, col *column.Col) (interface{}, error) {
@ -237,43 +270,47 @@ func (t *Table) Truncate(ctx context.Context) error {
}
// UpdateRecord implements table.Table UpdateRecord interface.
func (t *Table) UpdateRecord(ctx context.Context, h int64, currData []interface{}, newData []interface{}, touched []bool) error {
// if they are not set, and other data are changed, they will be updated by current timestamp too.
// set on update value
err := t.setOnUpdateData(ctx, touched, newData)
func (t *Table) UpdateRecord(ctx context.Context, h int64, oldData []interface{}, newData []interface{}, touched map[int]bool) error {
// We should check whether this table has on update column which state is write only.
currentData := make([]interface{}, len(t.writableCols()))
copy(currentData, newData)
// If they are not set, and other data are changed, they will be updated by current timestamp too.
err := t.setOnUpdateData(ctx, touched, currentData)
if err != nil {
return errors.Trace(err)
}
// set new value
if err := t.setNewData(ctx, h, newData); err != nil {
if err := t.setNewData(ctx, h, touched, currentData); err != nil {
return errors.Trace(err)
}
// rebuild index
if err := t.rebuildIndices(ctx, h, touched, currData, newData); err != nil {
if err := t.rebuildIndices(ctx, h, touched, oldData, currentData); err != nil {
return errors.Trace(err)
}
return nil
}
func (t *Table) setOnUpdateData(ctx context.Context, touched []bool, data []interface{}) error {
ucols := column.FindOnUpdateCols(t.Cols())
for _, c := range ucols {
if !touched[c.Offset] {
v, err := expression.GetTimeValue(ctx, expression.CurrentTimestamp, c.Tp, c.Decimal)
func (t *Table) setOnUpdateData(ctx context.Context, touched map[int]bool, data []interface{}) error {
ucols := column.FindOnUpdateCols(t.writableCols())
for _, col := range ucols {
if !touched[col.Offset] {
value, err := expression.GetTimeValue(ctx, expression.CurrentTimestamp, col.Tp, col.Decimal)
if err != nil {
return errors.Trace(err)
}
data[c.Offset] = v
touched[c.Offset] = true
data[col.Offset] = value
touched[col.Offset] = true
}
}
return nil
}
// SetColValue sets the column value.
// If the column untouched, we don't need to do this.
// SetColValue implements table.Table SetColValue interface.
func (t *Table) SetColValue(txn kv.Transaction, key []byte, data interface{}) error {
v, err := t.EncodeValue(data)
if err != nil {
@ -285,21 +322,27 @@ func (t *Table) SetColValue(txn kv.Transaction, key []byte, data interface{}) er
return nil
}
func (t *Table) setNewData(ctx context.Context, h int64, data []interface{}) error {
func (t *Table) setNewData(ctx context.Context, h int64, touched map[int]bool, data []interface{}) error {
txn, err := ctx.GetTxn(false)
if err != nil {
return errors.Trace(err)
}
for _, col := range t.Cols() {
if !touched[col.Offset] {
continue
}
k := t.RecordKey(h, col)
if err := t.SetColValue(txn, k, data[col.Offset]); err != nil {
return errors.Trace(err)
}
}
return nil
}
func (t *Table) rebuildIndices(ctx context.Context, h int64, touched []bool, oldData, newData []interface{}) error {
func (t *Table) rebuildIndices(ctx context.Context, h int64, touched map[int]bool, oldData []interface{}, newData []interface{}) error {
for _, idx := range t.Indices() {
idxTouched := false
for _, ic := range idx.Columns {
@ -325,6 +368,7 @@ func (t *Table) rebuildIndices(ctx context.Context, h int64, touched []bool, old
if err != nil {
return errors.Trace(err)
}
if err := t.BuildIndexForRow(ctx, h, newVs, idx); err != nil {
return errors.Trace(err)
}
@ -376,13 +420,25 @@ func (t *Table) AddRecord(ctx context.Context, r []interface{}) (recordID int64,
return 0, errors.Trace(err)
}
// column key -> column value
for _, c := range t.Cols() {
k := t.RecordKey(recordID, c)
if err := t.SetColValue(txn, k, r[c.Offset]); err != nil {
// Set public and write only column value.
for _, col := range t.writableCols() {
var value interface{}
key := t.RecordKey(recordID, col)
if col.State == model.StateWriteOnly {
value, _, err = GetColDefaultValue(ctx, &col.ColumnInfo)
if err != nil {
return 0, errors.Trace(err)
}
} else {
value = r[col.Offset]
}
err = t.SetColValue(txn, key, value)
if err != nil {
return 0, errors.Trace(err)
}
}
variable.GetSessionVars(ctx).AddAffectedRows(1)
return recordID, nil
}
@ -412,20 +468,25 @@ func (t *Table) RowWithCols(ctx context.Context, h int64, cols []*column.Col) ([
if err != nil {
return nil, errors.Trace(err)
}
// use the length of t.Cols() for alignment
v := make([]interface{}, len(t.Cols()))
for _, c := range cols {
k := t.RecordKey(h, c)
for _, col := range cols {
if col.State != model.StatePublic {
return nil, errors.Errorf("Cannot use none public column - %v", cols)
}
k := t.RecordKey(h, col)
data, err := txn.Get([]byte(k))
if err != nil {
return nil, errors.Trace(err)
}
val, err := t.DecodeValue(data, c)
val, err := t.DecodeValue(data, col)
if err != nil {
return nil, errors.Trace(err)
}
v[c.Offset] = val
v[col.Offset] = val
}
return v, nil
}
@ -471,10 +532,16 @@ func (t *Table) RemoveRow(ctx context.Context, h int64) error {
return errors.Trace(err)
}
// Remove row's colume one by one
for _, col := range t.Cols() {
for _, col := range t.Columns {
k := t.RecordKey(h, col)
err := txn.Delete([]byte(k))
if err != nil {
if col.State != model.StatePublic && errors2.ErrorEqual(err, kv.ErrNotExist) {
// If the column is not in public state, we may have not added the column,
// or already deleted the column, so skip ErrNotExist error.
continue
}
return errors.Trace(err)
}
}
@ -600,6 +667,36 @@ func (t *Table) AllocAutoID() (int64, error) {
return t.alloc.Alloc(t.ID)
}
// GetColDefaultValue gets default value of the column.
func GetColDefaultValue(ctx context.Context, col *model.ColumnInfo) (interface{}, bool, error) {
// Check no default value flag.
if mysql.HasNoDefaultValueFlag(col.Flag) && col.Tp != mysql.TypeEnum {
return nil, false, errors.Errorf("Field '%s' doesn't have a default value", col.Name)
}
// Check and get timestamp/datetime default value.
if col.Tp == mysql.TypeTimestamp || col.Tp == mysql.TypeDatetime {
if col.DefaultValue == nil {
return nil, true, nil
}
value, err := expression.GetTimeValue(ctx, col.DefaultValue, col.Tp, col.Decimal)
if err != nil {
return nil, true, errors.Errorf("Field '%s' get default value fail - %s", col.Name, errors.Trace(err))
}
return value, true, nil
} else if col.Tp == mysql.TypeEnum {
// For enum type, if no default value and not null is set,
// the default value is the first element of the enum list
if col.DefaultValue == nil && mysql.HasNotNullFlag(col.Flag) {
return col.FieldType.Elems[0], true, nil
}
}
return col.DefaultValue, true, nil
}
func init() {
table.TableFromMeta = TableFromMeta
}

View File

@ -82,7 +82,7 @@ func (ts *testSuite) TestBasic(c *C) {
_, err = tb.AddRecord(ctx, []interface{}{2, "abc"})
c.Assert(err, NotNil)
c.Assert(tb.UpdateRecord(ctx, rid, []interface{}{1, "abc"}, []interface{}{1, "cba"}, []bool{false, true}), IsNil)
c.Assert(tb.UpdateRecord(ctx, rid, []interface{}{1, "abc"}, []interface{}{1, "cba"}, map[int]bool{0: false, 1: true}), IsNil)
tb.IterRecords(ctx, tb.FirstKey(), tb.Cols(), func(h int64, data []interface{}, cols []*column.Col) (bool, error) {
return true, nil