stmt: update affected rows and test
This commit is contained in:
@ -14,17 +14,15 @@
|
||||
package stmts
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/column"
|
||||
"github.com/pingcap/tidb/context"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/kv"
|
||||
"github.com/pingcap/tidb/plan"
|
||||
"github.com/pingcap/tidb/rset"
|
||||
"github.com/pingcap/tidb/sessionctx/variable"
|
||||
"github.com/pingcap/tidb/table"
|
||||
"github.com/pingcap/tidb/util"
|
||||
"github.com/pingcap/tidb/util/errors2"
|
||||
"github.com/pingcap/tidb/util/format"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
@ -99,97 +97,45 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
h, err := t.AddRecord(ctx, row)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
if err != nil && !errors2.ErrorEqual(err, kv.ErrKeyExists) {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
if err = removeExistRow(ctx, t, row); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
if _, err = t.AddRecord(ctx, row); err != nil {
|
||||
// While the insertion fails because a duplicate-key error occurs for a primary key or unique index,
|
||||
// a storage engine may perform the REPLACE as an update rather than a delete plus insert.
|
||||
// See: http://dev.mysql.com/doc/refman/5.7/en/replace.html.
|
||||
if err = replaceRow(ctx, t, h, row); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
variable.GetSessionVars(ctx).AddAffectedRows(1)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func removeExistRow(ctx context.Context, t table.Table, replaceRow []interface{}) error {
|
||||
indices := make([]*column.IndexedCol, 0, len(t.Indices()))
|
||||
for _, idx := range t.Indices() {
|
||||
// TODO: handles the idx is composite primary key the affected rows.
|
||||
if idx.Unique || idx.Primary && len(idx.Columns) >= 1 {
|
||||
indices = append(indices, idx)
|
||||
}
|
||||
}
|
||||
if len(indices) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
txn, err := ctx.GetTxn(false)
|
||||
func replaceRow(ctx context.Context, t table.Table, handle int64, replaceRow []interface{}) error {
|
||||
row, err := t.Row(ctx, handle)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
it, err := txn.Seek([]byte(t.FirstKey()))
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
defer it.Close()
|
||||
|
||||
prefix := t.KeyPrefix()
|
||||
for it.Valid() && strings.HasPrefix(it.Key(), prefix) {
|
||||
handle, err0 := util.DecodeHandleFromRowKey(it.Key())
|
||||
if err0 != nil {
|
||||
return errors.Trace(err0)
|
||||
for i, val := range row {
|
||||
v, err := types.Compare(val, replaceRow[i])
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
row, err0 := t.Row(ctx, handle)
|
||||
if err0 != nil {
|
||||
return errors.Trace(err0)
|
||||
}
|
||||
for _, idx := range indices {
|
||||
v, err1 := compareIndex(t.Cols(), idx, row, replaceRow)
|
||||
if err1 != nil {
|
||||
return errors.Trace(err1)
|
||||
if v != 0 {
|
||||
touched := make([]bool, len(row))
|
||||
for i := 0; i < len(touched); i++ {
|
||||
touched[i] = true
|
||||
}
|
||||
if v != 0 {
|
||||
continue
|
||||
}
|
||||
if err = removeRow(ctx, t, handle, row); err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
}
|
||||
|
||||
rk := t.RecordKey(handle, nil)
|
||||
if it, err0 = kv.NextUntil(it, util.RowKeyPrefixFilter(rk)); err0 != nil {
|
||||
return errors.Trace(err0)
|
||||
variable.GetSessionVars(ctx).AddAffectedRows(1)
|
||||
return t.UpdateRecord(ctx, handle, row, replaceRow, touched)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// compareIndex returns an integer comparing the idx old with new.
|
||||
// old > new -> 1
|
||||
// old = new -> 0
|
||||
// old < new -> -1
|
||||
func compareIndex(cols []*column.Col, idx *column.IndexedCol, row, replaceRow []interface{}) (v int, err error) {
|
||||
nulls := 0
|
||||
for _, idxCol := range idx.Columns {
|
||||
col := column.FindCol(cols, idxCol.Name.L)
|
||||
if col == nil {
|
||||
return 0, errors.Errorf("No such column: %v", idx)
|
||||
}
|
||||
v, err = types.Compare(row[col.Offset], replaceRow[col.Offset])
|
||||
if err != nil {
|
||||
return 0, errors.Trace(err)
|
||||
}
|
||||
if v != 0 {
|
||||
break
|
||||
}
|
||||
if row[col.Offset] == nil {
|
||||
nulls++
|
||||
}
|
||||
}
|
||||
if nulls == len(idx.Columns) {
|
||||
v = -1
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
@ -93,6 +93,11 @@ func (s *testStmtSuite) TestReplace(c *C) {
|
||||
ret := mustExec(c, s.testDB, replaceUniqueIndexSQL)
|
||||
rows, err := ret.RowsAffected()
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(rows, Equals, int64(1))
|
||||
replaceUniqueIndexSQL = `replace into replace_test_3 set c1=1, c2=1;`
|
||||
ret = mustExec(c, s.testDB, replaceUniqueIndexSQL)
|
||||
rows, err = ret.RowsAffected()
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(rows, Equals, int64(2))
|
||||
|
||||
replaceUniqueIndexSQL = `replace into replace_test_3 set c2=NULL;`
|
||||
@ -119,5 +124,5 @@ func (s *testStmtSuite) TestReplace(c *C) {
|
||||
ret = mustExec(c, s.testDB, replacePrimaryKeySQL)
|
||||
rows, err = ret.RowsAffected()
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(rows, Equals, int64(2))
|
||||
c.Assert(rows, Equals, int64(1))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user