From c8feff878f6a4e7f0ce4f499192b2ca86f8140ca Mon Sep 17 00:00:00 2001 From: Cholerae Hu Date: Tue, 11 Jul 2017 04:39:35 -0500 Subject: [PATCH] conn: fix database info leaking problem (#3699) --- server/conn.go | 6 ++++++ server/driver_tidb.go | 6 ------ server/server_test.go | 18 ++++++++++++++++++ server/tidb_test.go | 4 ++++ 4 files changed, 28 insertions(+), 6 deletions(-) diff --git a/server/conn.go b/server/conn.go index 9f002a856a..41bdd1df5e 100644 --- a/server/conn.go +++ b/server/conn.go @@ -317,6 +317,12 @@ func (cc *clientConn) readHandshakeResponse() error { return errors.Trace(errAccessDenied.GenByArgs(cc.user, host, "YES")) } } + if cc.dbname != "" { + _, err = cc.ctx.Execute("use " + cc.dbname) + if err != nil { + return errors.Trace(err) + } + } cc.ctx.SetSessionManager(cc.server) return nil } diff --git a/server/driver_tidb.go b/server/driver_tidb.go index 78c7def574..d7f851639f 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -129,12 +129,6 @@ func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8, } session.SetClientCapability(capability) session.SetConnectionID(connID) - if dbname != "" { - _, err = session.Execute("use " + dbname) - if err != nil { - return nil, errors.Trace(err) - } - } tc := &TiDBContext{ session: session, currentDB: dbname, diff --git a/server/server_test.go b/server/server_test.go index 419fa3a71c..9c5d2dcc04 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -487,6 +487,24 @@ func runTestAuth(c *C) { }) } +func runTestIssue3682(c *C) { + runTests(c, dsn, func(dbt *DBTest) { + dbt.mustExec(`CREATE USER 'abc'@'%' IDENTIFIED BY '123';`) + dbt.mustExec(`FLUSH PRIVILEGES;`) + }) + newDsn := "abc:123@tcp(127.0.0.1:4001)/test?strict=true" + runTests(c, newDsn, func(dbt *DBTest) { + dbt.mustExec(`USE mysql;`) + }) + wrongDsn := "abc:456@tcp(127.0.0.1:4001)/a_database_not_exist?strict=true" + db, err := sql.Open("mysql", wrongDsn) + c.Assert(err, IsNil) + defer db.Close() + err = db.Ping() + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "Error 1045: Access denied for user 'abc'@'127.0.0.1' (using password: YES)") +} + func runTestIssues(c *C) { // For issue #263 unExistsSchemaDsn := "root@tcp(localhost:4001)/unexists_schema?strict=true" diff --git a/server/tidb_test.go b/server/tidb_test.go index ffe4af812b..77ac0dcb7c 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -160,3 +160,7 @@ func (ts *TidbTestSuite) TestIssue3680(c *C) { c.Assert(err, NotNil) c.Assert(err.Error(), Equals, "Error 1045: Access denied for user 'non_existing_user'@'127.0.0.1' (using password: YES)") } + +func (ts *TidbTestSuite) TestIssue3682(c *C) { + runTestIssue3682(c) +}