diff --git a/expression/builtin/info.go b/expression/builtin/info.go index 3e5e31a091..14687d6b42 100644 --- a/expression/builtin/info.go +++ b/expression/builtin/info.go @@ -22,7 +22,6 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/sessionctx/db" "github.com/pingcap/tidb/sessionctx/variable" - "github.com/pingcap/tidb/terror" ) // See: https://dev.mysql.com/doc/refman/5.7/en/information-functions.html @@ -69,30 +68,11 @@ func builtinUser(args []interface{}, data map[interface{}]interface{}) (v interf return variable.GetSessionVars(ctx).User, nil } -// connectionIDKeyType is a dummy type to avoid naming collision in context. -type connectionIDKeyType int - -// String defines a Stringer function for debugging and pretty printing. -func (k connectionIDKeyType) String() string { - return "connection_id" -} - -// ConnectionIDKey is the key for get connection id from context -const ConnectionIDKey connectionIDKeyType = 0 - func builtinConnectionID(args []interface{}, data map[interface{}]interface{}) (v interface{}, err error) { c, ok := data[ExprEvalArgCtx] if !ok { return nil, errors.Errorf("Missing ExprEvalArgCtx when evalue builtin") } ctx := c.(context.Context) - idValue := ctx.Value(ConnectionIDKey) - if idValue == nil { - return nil, terror.MissConnectionID - } - id, ok := idValue.(int64) - if !ok { - return nil, terror.MissConnectionID.Gen("connection id is not int64 but %T", idValue) - } - return id, nil + return variable.GetSessionVars(ctx).ConnectionID, nil } diff --git a/expression/builtin/info_test.go b/expression/builtin/info_test.go index 260596aff8..93eca4577b 100644 --- a/expression/builtin/info_test.go +++ b/expression/builtin/info_test.go @@ -86,9 +86,11 @@ func (s *testBuiltinSuite) TestConnectionID(c *C) { ctx := mock.NewContext() m := map[interface{}]interface{}{} variable.BindSessionVars(ctx) - ctx.SetValue(ConnectionIDKey, int64(1)) + sessionVars := variable.GetSessionVars(ctx) + sessionVars.ConnectionID = uint64(1) + m[ExprEvalArgCtx] = ctx v, err := builtinConnectionID(nil, m) c.Assert(err, IsNil) - c.Assert(v, Equals, int64(1)) + c.Assert(v, Equals, uint64(1)) } diff --git a/session.go b/session.go index 730b18e617..5fd018cef8 100644 --- a/session.go +++ b/session.go @@ -29,7 +29,6 @@ import ( "github.com/juju/errors" "github.com/ngaut/log" "github.com/pingcap/tidb/context" - "github.com/pingcap/tidb/expression/builtin" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" @@ -65,6 +64,7 @@ type Session interface { ExecutePreparedStmt(stmtID uint32, param ...interface{}) (rset.Recordset, error) DropPreparedStmt(stmtID uint32) error SetClientCapability(uint32) // Set client capability flags + SetConnectionID(uint64) Close() error Retry() error Auth(user string, auth []byte, salt []byte) bool @@ -145,6 +145,10 @@ func (s *session) SetClientCapability(capability uint32) { variable.GetSessionVars(s).ClientCapability = capability } +func (s *session) SetConnectionID(connectionID uint64) { + variable.GetSessionVars(s).ConnectionID = connectionID +} + func (s *session) FinishTxn(rollback bool) error { // transaction has already been committed or rolled back if s.txn == nil { @@ -580,9 +584,6 @@ func CreateSession(store kv.Storage) (Session, error) { variable.BindSessionVars(s) variable.GetSessionVars(s).SetStatusFlag(mysql.ServerStatusAutocommit, true) - // set connection id - s.SetValue(builtin.ConnectionIDKey, s.sid) - // session implements variable.GlobalVarAccessor. Bind it to ctx. variable.BindGlobalVarAccessor(s, s) diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 23c04e0ebb..d0c3acfb9a 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -36,6 +36,9 @@ type SessionVars struct { // Client capability ClientCapability uint32 + // Connection ID + ConnectionID uint64 + // Found rows FoundRows uint64 diff --git a/tidb-server/server/conn.go b/tidb-server/server/conn.go index 4b7e0503b0..3943173b17 100644 --- a/tidb-server/server/conn.go +++ b/tidb-server/server/conn.go @@ -191,7 +191,7 @@ func (cc *clientConn) readHandshakeResponse() error { } } // Open session and do auth - cc.ctx, err = cc.server.driver.OpenCtx(cc.capability, uint8(cc.collation), cc.dbname) + cc.ctx, err = cc.server.driver.OpenCtx(uint64(cc.connectionID), cc.capability, uint8(cc.collation), cc.dbname) if err != nil { cc.Close() return errors.Trace(err) diff --git a/tidb-server/server/driver.go b/tidb-server/server/driver.go index 1f1c9a15ff..d00444c46c 100644 --- a/tidb-server/server/driver.go +++ b/tidb-server/server/driver.go @@ -15,8 +15,8 @@ package server // IDriver opens IContext. type IDriver interface { - // OpenCtx opens an IContext with client capability, collation and dbname. - OpenCtx(capability uint32, collation uint8, dbname string) (IContext, error) + // OpenCtx opens an IContext with connection id, client capability, collation and dbname. + OpenCtx(connID uint64, capability uint32, collation uint8, dbname string) (IContext, error) } // IContext is the interface to execute commant. diff --git a/tidb-server/server/driver_tidb.go b/tidb-server/server/driver_tidb.go index c7f6fce4d6..aa59f6f189 100644 --- a/tidb-server/server/driver_tidb.go +++ b/tidb-server/server/driver_tidb.go @@ -110,9 +110,10 @@ func (ts *TiDBStatement) Close() error { } // OpenCtx implements IDriver. -func (qd *TiDBDriver) OpenCtx(capability uint32, collation uint8, dbname string) (IContext, error) { +func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8, dbname string) (IContext, error) { session, _ := tidb.CreateSession(qd.store) session.SetClientCapability(capability) + session.SetConnectionID(connID) if dbname != "" { _, err := session.Execute("use " + dbname) if err != nil {