diff --git a/kv/error.go b/kv/error.go index 23523a1e16..bd6a1923ea 100644 --- a/kv/error.go +++ b/kv/error.go @@ -26,8 +26,6 @@ var ( ErrClosed = errors.New("Error: Transaction already closed") // ErrNotExist is used when try to get an entry with an unexist key from KV store. ErrNotExist = errors.New("Error: key not exist") - // ErrKeyExists is used when try to put an entry to KV store. - ErrKeyExists = errors.New("Error: key already exist") // ErrConditionNotMatch is used when condition is not met. ErrConditionNotMatch = errors.New("Error: Condition not match") // ErrLockConflict is used when try to lock an already locked key. diff --git a/kv/index_iter.go b/kv/index_iter.go index 303dc288b1..c6db0f54a3 100644 --- a/kv/index_iter.go +++ b/kv/index_iter.go @@ -20,6 +20,7 @@ import ( "strings" "github.com/juju/errors" + "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/util/codec" ) @@ -193,7 +194,7 @@ func (c *kvIndex) Create(rm RetrieverMutator, indexedValues []interface{}, h int return errors.Trace(err) } - return errors.Trace(ErrKeyExists) + return errors.Trace(terror.ErrKeyExists) } // Delete removes the entry for handle h and indexdValues from KV index. @@ -282,7 +283,7 @@ func (c *kvIndex) Exist(rm RetrieverMutator, indexedValues []interface{}, h int6 } if handle != h { - return true, handle, errors.Trace(ErrKeyExists) + return true, handle, errors.Trace(terror.ErrKeyExists) } return true, handle, nil diff --git a/kv/union_store.go b/kv/union_store.go index d286e5786c..8c0e1d6f2d 100644 --- a/kv/union_store.go +++ b/kv/union_store.go @@ -19,6 +19,7 @@ import ( "github.com/juju/errors" "github.com/ngaut/pool" + "github.com/pingcap/tidb/terror" ) // UnionStore is a store that wraps a snapshot for read and a BufferStore for buffered write. @@ -212,7 +213,7 @@ func (us *unionStore) CheckLazyConditionPairs() error { for ; it.Valid(); it.Next() { if len(it.Value()) == 0 { if _, exist := values[it.Key()]; exist { - return errors.Trace(ErrKeyExists) + return errors.Trace(terror.ErrKeyExists) } } else { if bytes.Compare(values[it.Key()], it.Value()) != 0 { diff --git a/plan/plans/index.go b/plan/plans/index.go index 3ea55f87cb..645efa44bf 100644 --- a/plan/plans/index.go +++ b/plan/plans/index.go @@ -436,12 +436,12 @@ func (r *indexPlan) pointLookup(ctx context.Context, val interface{}) (*plan.Row } var exist bool var h int64 - // We expect a kv.ErrKeyExists Error because we pass -1 as the handle which is not equal to the existed handle. + // We expect a terror.ErrKeyExists Error because we pass -1 as the handle which is not equal to the existed handle. exist, h, err = r.idx.Exist(txn, []interface{}{val}, -1) if !exist { return nil, errors.Trace(err) } - if terror.ErrorNotEqual(kv.ErrKeyExists, err) { + if terror.ErrorNotEqual(terror.ErrKeyExists, err) { return nil, errors.Trace(err) } var row *plan.Row diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index 7c298f1c06..831b990e5c 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -215,7 +215,7 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) continue } - if len(s.OnDuplicate) == 0 || !terror.ErrorEqual(err, kv.ErrKeyExists) { + if len(s.OnDuplicate) == 0 || !terror.ErrorEqual(err, terror.ErrKeyExists) { return nil, errors.Trace(err) } if err = execOnDuplicateUpdate(ctx, t, row, h, toUpdateColumns); err != nil { diff --git a/stmt/stmts/replace.go b/stmt/stmts/replace.go index a6018f6771..07793a08d0 100644 --- a/stmt/stmts/replace.go +++ b/stmt/stmts/replace.go @@ -16,7 +16,6 @@ package stmts import ( "github.com/juju/errors" "github.com/pingcap/tidb/context" - "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" @@ -83,7 +82,7 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error if err == nil { continue } - if err != nil && !terror.ErrorEqual(err, kv.ErrKeyExists) { + if err != nil && !terror.ErrorEqual(err, terror.ErrKeyExists) { return nil, errors.Trace(err) } diff --git a/store/localstore/kv.go b/store/localstore/kv.go index e6e77683fe..8712acb8c9 100644 --- a/store/localstore/kv.go +++ b/store/localstore/kv.go @@ -101,7 +101,8 @@ func (s *dbStore) CommitTxn(txn *dbTxn) error { return errors.Trace(err) } -func (s *dbStore) seekWorker(seekCh chan *command) { +func (s *dbStore) seekWorker(wg *sync.WaitGroup, seekCh chan *command) { + defer wg.Done() for { var pending []*command select { @@ -131,8 +132,10 @@ func (s *dbStore) seekWorker(seekCh chan *command) { func (s *dbStore) scheduler() { closed := false seekCh := make(chan *command, 1000) + wgSeekWorkers := &sync.WaitGroup{} + wgSeekWorkers.Add(maxSeekWorkers) for i := 0; i < maxSeekWorkers; i++ { - go s.seekWorker(seekCh) + go s.seekWorker(wgSeekWorkers, seekCh) } for { @@ -150,9 +153,10 @@ func (s *dbStore) scheduler() { } case <-s.closeCh: closed = true - s.wg.Done() // notify seek worker to exit close(seekCh) + wgSeekWorkers.Wait() + s.wg.Done() } } } diff --git a/table/tables/tables.go b/table/tables/tables.go index 920636a845..267f65d82e 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -430,7 +430,7 @@ func (t *Table) AddRecord(ctx context.Context, r []interface{}, h int64) (record } colVals, _ := v.FetchValues(r) if err = v.X.Create(bs, colVals, recordID); err != nil { - if terror.ErrorEqual(err, kv.ErrKeyExists) { + if terror.ErrorEqual(err, terror.ErrKeyExists) { // Get the duplicate row handle // For insert on duplicate syntax, we should update the row iter, _, err1 := v.X.Seek(bs, colVals) diff --git a/terror/terror.go b/terror/terror.go index c4818d3b76..0208aea6d6 100644 --- a/terror/terror.go +++ b/terror/terror.go @@ -19,6 +19,8 @@ import ( "strconv" "github.com/juju/errors" + "github.com/ngaut/log" + "github.com/pingcap/tidb/mysql" ) // Common base error instances. @@ -34,6 +36,8 @@ var ( UnknownSystemVar = ClassVariable.New(CodeUnknownSystemVar, "unknown system variable") MissConnectionID = ClassExpression.New(CodeMissConnectionID, "miss connection id information") + + ErrKeyExists = ClassKV.New(CodeKeyExists, "key already exist") ) // ErrCode represents a specific error type in a error class. @@ -57,6 +61,7 @@ const ( const ( CodeIncompatibleDBFormat ErrCode = iota + 1 CodeNoDataForHandle + CodeKeyExists ) // Variable error codes. @@ -190,6 +195,56 @@ func (e *Error) NotEqual(err error) bool { return !e.Equal(err) } +// ToSQLError convert Error to mysql.SQLError. +func (e *Error) ToSQLError() *mysql.SQLError { + code := e.getMySQLErrorCode() + return mysql.NewErrf(code, e.message) +} + +var defaultMySQLErrorCode uint16 + +func (e *Error) getMySQLErrorCode() uint16 { + codeMap, ok := errClassToMySQLCodes[e.class] + if !ok { + log.Warnf("Unknown error class: %v", e.class) + return defaultMySQLErrorCode + } + code, ok := codeMap[e.code] + if !ok { + log.Warnf("Unknown error class: %v", e.class) + return defaultMySQLErrorCode + } + return code +} + +var ( + // ErrCode to mysql error code map. + parserMySQLErrCodes = map[ErrCode]uint16{} + schemaMySQLErrCodes = map[ErrCode]uint16{} + optimizerMySQLErrCodes = map[ErrCode]uint16{} + executorMySQLErrCodes = map[ErrCode]uint16{} + kvMySQLErrCodes = map[ErrCode]uint16{ + CodeKeyExists: mysql.ErrDupEntry, + } + serverMySQLErrCodes = map[ErrCode]uint16{} + expressionMySQLErrCodes = map[ErrCode]uint16{} + + // ErrClass to code-map map. + errClassToMySQLCodes map[ErrClass](map[ErrCode]uint16) +) + +func init() { + errClassToMySQLCodes = make(map[ErrClass](map[ErrCode]uint16)) + errClassToMySQLCodes[ClassParser] = parserMySQLErrCodes + errClassToMySQLCodes[ClassSchema] = schemaMySQLErrCodes + errClassToMySQLCodes[ClassOptimizer] = optimizerMySQLErrCodes + errClassToMySQLCodes[ClassExecutor] = executorMySQLErrCodes + errClassToMySQLCodes[ClassKV] = kvMySQLErrCodes + errClassToMySQLCodes[ClassServer] = serverMySQLErrCodes + errClassToMySQLCodes[ClassExpression] = expressionMySQLErrCodes + defaultMySQLErrorCode = mysql.ErrUnknown +} + // ErrorEqual returns a boolean indicating whether err1 is equal to err2. func ErrorEqual(err1, err2 error) bool { e1 := errors.Cause(err1) diff --git a/terror/terror_test.go b/terror/terror_test.go index 22ed3f63eb..30a1aa616e 100644 --- a/terror/terror_test.go +++ b/terror/terror_test.go @@ -19,6 +19,7 @@ import ( "github.com/juju/errors" . "github.com/pingcap/check" + "github.com/pingcap/tidb/mysql" ) func TestT(t *testing.T) { @@ -110,3 +111,10 @@ func (s *testTErrorSuite) TestErrorEqual(c *C) { c.Assert(ErrorEqual(te1, te3), IsFalse) c.Assert(ErrorEqual(te3, te4), IsFalse) } + +func (s *testTErrorSuite) TestMySQLErrorCode(c *C) { + ke := ErrKeyExists.Gen("key exists") + me := ke.ToSQLError() + c.Assert(me.Code, Equals, uint16(mysql.ErrDupEntry)) + c.Assert(me.Message, Equals, "key exists") +} diff --git a/tidb-server/server/conn.go b/tidb-server/server/conn.go index 8775920230..4b7e0503b0 100644 --- a/tidb-server/server/conn.go +++ b/tidb-server/server/conn.go @@ -319,10 +319,15 @@ func (cc *clientConn) writeOK() error { } func (cc *clientConn) writeError(e error) error { - var m *mysql.SQLError - var ok bool + var ( + m *mysql.SQLError + te *terror.Error + ok bool + ) originErr := errors.Cause(e) - if m, ok = originErr.(*mysql.SQLError); !ok { + if te, ok = originErr.(*terror.Error); ok { + m = te.ToSQLError() + } else { m = mysql.NewErrf(mysql.ErrUnknown, e.Error()) } diff --git a/tidb-server/server/server_test.go b/tidb-server/server/server_test.go index 6b824d2046..ab2a95a4eb 100644 --- a/tidb-server/server/server_test.go +++ b/tidb-server/server/server_test.go @@ -17,8 +17,9 @@ import ( "database/sql" "testing" - _ "github.com/go-sql-driver/mysql" + "github.com/go-sql-driver/mysql" . "github.com/pingcap/check" + tmysql "github.com/pingcap/tidb/mysql" ) func TestT(t *testing.T) { @@ -220,6 +221,25 @@ func runTestConcurrentUpdate(c *C) { }) } +func runTestErrorCode(c *C) { + runTests(c, dsn, func(dbt *DBTest) { + dbt.mustExec("create table test (c int PRIMARY KEY);") + dbt.mustExec("insert into test values (1);") + txn1, err := dbt.db.Begin() + c.Assert(err, IsNil) + _, err = txn1.Exec("insert into test values(1)") + c.Assert(err, IsNil) + err = txn1.Commit() + checkErrorCode(c, err, tmysql.ErrDupEntry) + }) +} + +func checkErrorCode(c *C, e error, code uint16) { + me, ok := e.(*mysql.MySQLError) + c.Assert(ok, IsTrue) + c.Assert(me.Number, Equals, code) +} + func runTestAuth(c *C) { runTests(c, dsn, func(dbt *DBTest) { dbt.mustExec(`CREATE USER 'test'@'%' IDENTIFIED BY '123';`) diff --git a/tidb-server/server/tidb_test.go b/tidb-server/server/tidb_test.go index 10579a7abc..e60ec6d025 100644 --- a/tidb-server/server/tidb_test.go +++ b/tidb-server/server/tidb_test.go @@ -68,6 +68,10 @@ func (ts *TidbTestSuite) TestConcurrentUpdate(c *C) { runTestConcurrentUpdate(c) } +func (ts *TidbTestSuite) TestErrorCode(c *C) { + runTestErrorCode(c) +} + func (ts *TidbTestSuite) TestAuth(c *C) { runTestAuth(c) }