stmt: update affected rows and test

This commit is contained in:
xia
2015-10-24 16:19:10 +08:00
parent bacb489963
commit b4562dc8f5
2 changed files with 32 additions and 81 deletions

View File

@ -14,17 +14,15 @@
package stmts
import (
"strings"
"github.com/juju/errors"
"github.com/pingcap/tidb/column"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/plan"
"github.com/pingcap/tidb/rset"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/errors2"
"github.com/pingcap/tidb/util/format"
"github.com/pingcap/tidb/util/types"
)
@ -99,97 +97,45 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error
if err != nil {
return nil, errors.Trace(err)
}
h, err := t.AddRecord(ctx, row)
if err == nil {
continue
}
if err != nil && !errors2.ErrorEqual(err, kv.ErrKeyExists) {
return nil, errors.Trace(err)
}
if err = removeExistRow(ctx, t, row); err != nil {
return nil, errors.Trace(err)
}
if _, err = t.AddRecord(ctx, row); err != nil {
// While the insertion fails because a duplicate-key error occurs for a primary key or unique index,
// a storage engine may perform the REPLACE as an update rather than a delete plus insert.
// See: http://dev.mysql.com/doc/refman/5.7/en/replace.html.
if err = replaceRow(ctx, t, h, row); err != nil {
return nil, errors.Trace(err)
}
variable.GetSessionVars(ctx).AddAffectedRows(1)
}
return nil, nil
}
func removeExistRow(ctx context.Context, t table.Table, replaceRow []interface{}) error {
indices := make([]*column.IndexedCol, 0, len(t.Indices()))
for _, idx := range t.Indices() {
// TODO: handles the idx is composite primary key the affected rows.
if idx.Unique || idx.Primary && len(idx.Columns) >= 1 {
indices = append(indices, idx)
}
}
if len(indices) == 0 {
return nil
}
txn, err := ctx.GetTxn(false)
func replaceRow(ctx context.Context, t table.Table, handle int64, replaceRow []interface{}) error {
row, err := t.Row(ctx, handle)
if err != nil {
return errors.Trace(err)
}
it, err := txn.Seek([]byte(t.FirstKey()))
if err != nil {
return errors.Trace(err)
}
defer it.Close()
prefix := t.KeyPrefix()
for it.Valid() && strings.HasPrefix(it.Key(), prefix) {
handle, err0 := util.DecodeHandleFromRowKey(it.Key())
if err0 != nil {
return errors.Trace(err0)
for i, val := range row {
v, err := types.Compare(val, replaceRow[i])
if err != nil {
return errors.Trace(err)
}
row, err0 := t.Row(ctx, handle)
if err0 != nil {
return errors.Trace(err0)
}
for _, idx := range indices {
v, err1 := compareIndex(t.Cols(), idx, row, replaceRow)
if err1 != nil {
return errors.Trace(err1)
if v != 0 {
touched := make([]bool, len(row))
for i := 0; i < len(touched); i++ {
touched[i] = true
}
if v != 0 {
continue
}
if err = removeRow(ctx, t, handle, row); err != nil {
return errors.Trace(err)
}
}
rk := t.RecordKey(handle, nil)
if it, err0 = kv.NextUntil(it, util.RowKeyPrefixFilter(rk)); err0 != nil {
return errors.Trace(err0)
variable.GetSessionVars(ctx).AddAffectedRows(1)
return t.UpdateRecord(ctx, handle, row, replaceRow, touched)
}
}
return nil
}
// compareIndex returns an integer comparing the idx old with new.
// old > new -> 1
// old = new -> 0
// old < new -> -1
func compareIndex(cols []*column.Col, idx *column.IndexedCol, row, replaceRow []interface{}) (v int, err error) {
nulls := 0
for _, idxCol := range idx.Columns {
col := column.FindCol(cols, idxCol.Name.L)
if col == nil {
return 0, errors.Errorf("No such column: %v", idx)
}
v, err = types.Compare(row[col.Offset], replaceRow[col.Offset])
if err != nil {
return 0, errors.Trace(err)
}
if v != 0 {
break
}
if row[col.Offset] == nil {
nulls++
}
}
if nulls == len(idx.Columns) {
v = -1
}
return v, nil
}

View File

@ -93,6 +93,11 @@ func (s *testStmtSuite) TestReplace(c *C) {
ret := mustExec(c, s.testDB, replaceUniqueIndexSQL)
rows, err := ret.RowsAffected()
c.Assert(err, IsNil)
c.Assert(rows, Equals, int64(1))
replaceUniqueIndexSQL = `replace into replace_test_3 set c1=1, c2=1;`
ret = mustExec(c, s.testDB, replaceUniqueIndexSQL)
rows, err = ret.RowsAffected()
c.Assert(err, IsNil)
c.Assert(rows, Equals, int64(2))
replaceUniqueIndexSQL = `replace into replace_test_3 set c2=NULL;`
@ -119,5 +124,5 @@ func (s *testStmtSuite) TestReplace(c *C) {
ret = mustExec(c, s.testDB, replacePrimaryKeySQL)
rows, err = ret.RowsAffected()
c.Assert(err, IsNil)
c.Assert(rows, Equals, int64(2))
c.Assert(rows, Equals, int64(1))
}