diff --git a/stmt/stmts/replace.go b/stmt/stmts/replace.go index 486f79b8c0..e5890c8523 100644 --- a/stmt/stmts/replace.go +++ b/stmt/stmts/replace.go @@ -14,17 +14,15 @@ package stmts import ( - "strings" - "github.com/juju/errors" - "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/rset" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" - "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/errors2" "github.com/pingcap/tidb/util/format" "github.com/pingcap/tidb/util/types" ) @@ -99,97 +97,45 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error if err != nil { return nil, errors.Trace(err) } + h, err := t.AddRecord(ctx, row) + if err == nil { + continue + } + if err != nil && !errors2.ErrorEqual(err, kv.ErrKeyExists) { + return nil, errors.Trace(err) + } - if err = removeExistRow(ctx, t, row); err != nil { - return nil, errors.Trace(err) - } - if _, err = t.AddRecord(ctx, row); err != nil { + // While the insertion fails because a duplicate-key error occurs for a primary key or unique index, + // a storage engine may perform the REPLACE as an update rather than a delete plus insert. + // See: http://dev.mysql.com/doc/refman/5.7/en/replace.html. + if err = replaceRow(ctx, t, h, row); err != nil { return nil, errors.Trace(err) } + variable.GetSessionVars(ctx).AddAffectedRows(1) } return nil, nil } -func removeExistRow(ctx context.Context, t table.Table, replaceRow []interface{}) error { - indices := make([]*column.IndexedCol, 0, len(t.Indices())) - for _, idx := range t.Indices() { - // TODO: handles the idx is composite primary key the affected rows. - if idx.Unique || idx.Primary && len(idx.Columns) >= 1 { - indices = append(indices, idx) - } - } - if len(indices) == 0 { - return nil - } - - txn, err := ctx.GetTxn(false) +func replaceRow(ctx context.Context, t table.Table, handle int64, replaceRow []interface{}) error { + row, err := t.Row(ctx, handle) if err != nil { return errors.Trace(err) } - it, err := txn.Seek([]byte(t.FirstKey())) - if err != nil { - return errors.Trace(err) - } - defer it.Close() - - prefix := t.KeyPrefix() - for it.Valid() && strings.HasPrefix(it.Key(), prefix) { - handle, err0 := util.DecodeHandleFromRowKey(it.Key()) - if err0 != nil { - return errors.Trace(err0) + for i, val := range row { + v, err := types.Compare(val, replaceRow[i]) + if err != nil { + return errors.Trace(err) } - row, err0 := t.Row(ctx, handle) - if err0 != nil { - return errors.Trace(err0) - } - for _, idx := range indices { - v, err1 := compareIndex(t.Cols(), idx, row, replaceRow) - if err1 != nil { - return errors.Trace(err1) + if v != 0 { + touched := make([]bool, len(row)) + for i := 0; i < len(touched); i++ { + touched[i] = true } - if v != 0 { - continue - } - if err = removeRow(ctx, t, handle, row); err != nil { - return errors.Trace(err) - } - } - - rk := t.RecordKey(handle, nil) - if it, err0 = kv.NextUntil(it, util.RowKeyPrefixFilter(rk)); err0 != nil { - return errors.Trace(err0) + variable.GetSessionVars(ctx).AddAffectedRows(1) + return t.UpdateRecord(ctx, handle, row, replaceRow, touched) } } return nil } - -// compareIndex returns an integer comparing the idx old with new. -// old > new -> 1 -// old = new -> 0 -// old < new -> -1 -func compareIndex(cols []*column.Col, idx *column.IndexedCol, row, replaceRow []interface{}) (v int, err error) { - nulls := 0 - for _, idxCol := range idx.Columns { - col := column.FindCol(cols, idxCol.Name.L) - if col == nil { - return 0, errors.Errorf("No such column: %v", idx) - } - v, err = types.Compare(row[col.Offset], replaceRow[col.Offset]) - if err != nil { - return 0, errors.Trace(err) - } - if v != 0 { - break - } - if row[col.Offset] == nil { - nulls++ - } - } - if nulls == len(idx.Columns) { - v = -1 - } - - return v, nil -} diff --git a/stmt/stmts/replace_test.go b/stmt/stmts/replace_test.go index fed415c636..fb22349161 100644 --- a/stmt/stmts/replace_test.go +++ b/stmt/stmts/replace_test.go @@ -93,6 +93,11 @@ func (s *testStmtSuite) TestReplace(c *C) { ret := mustExec(c, s.testDB, replaceUniqueIndexSQL) rows, err := ret.RowsAffected() c.Assert(err, IsNil) + c.Assert(rows, Equals, int64(1)) + replaceUniqueIndexSQL = `replace into replace_test_3 set c1=1, c2=1;` + ret = mustExec(c, s.testDB, replaceUniqueIndexSQL) + rows, err = ret.RowsAffected() + c.Assert(err, IsNil) c.Assert(rows, Equals, int64(2)) replaceUniqueIndexSQL = `replace into replace_test_3 set c2=NULL;` @@ -119,5 +124,5 @@ func (s *testStmtSuite) TestReplace(c *C) { ret = mustExec(c, s.testDB, replacePrimaryKeySQL) rows, err = ret.RowsAffected() c.Assert(err, IsNil) - c.Assert(rows, Equals, int64(2)) + c.Assert(rows, Equals, int64(1)) }