extension: add error and active roles info to extension.ConnEventInfo (#38752)

close pingcap/tidb#38493
This commit is contained in:
王超
2022-11-01 13:33:59 +08:00
committed by GitHub
parent 38e9aa02ef
commit ecdc0f74ed
8 changed files with 106 additions and 37 deletions

View File

@ -13,6 +13,7 @@ go_library(
importpath = "github.com/pingcap/tidb/extension",
visibility = ["//visibility:public"],
deps = [
"//parser/auth",
"//sessionctx/variable",
"//types",
"//util/chunk",

View File

@ -63,7 +63,7 @@ func (c *bootstrapContext) SessionPool() extension.SessionPool {
return c.sessionPool
}
// Bootstrap bootstrap all extensions
// Bootstrap bootstraps all extensions
func Bootstrap(ctx context.Context, do *domain.Domain) error {
extensions, err := extension.GetExtensions()
if err != nil {

View File

@ -29,7 +29,7 @@ func (es *Extensions) Manifests() []*Manifest {
return manifests
}
// Bootstrap bootstrap all extensions
// Bootstrap bootstraps all extensions
func (es *Extensions) Bootstrap(ctx BootstrapContext) error {
if es == nil {
return nil

View File

@ -14,10 +14,17 @@
package extension
import "github.com/pingcap/tidb/sessionctx/variable"
import (
"github.com/pingcap/tidb/parser/auth"
"github.com/pingcap/tidb/sessionctx/variable"
)
// ConnEventInfo is the connection info for the event
type ConnEventInfo variable.ConnectionInfo
type ConnEventInfo struct {
*variable.ConnectionInfo
ActiveRoles []*auth.RoleIdentity
Error error
}
// ConnEventTp is the type of the connection event
type ConnEventTp uint8
@ -60,13 +67,12 @@ type SessionExtensions struct {
}
// OnConnectionEvent will be called when a connection event happens
func (es *SessionExtensions) OnConnectionEvent(tp ConnEventTp, info *variable.ConnectionInfo) {
func (es *SessionExtensions) OnConnectionEvent(tp ConnEventTp, event *ConnEventInfo) {
if es == nil {
return
}
eventInfo := ConnEventInfo(*info)
for _, fn := range es.connectionEventFuncs {
fn(tp, &eventInfo)
fn(tp, event)
}
}

View File

@ -2510,8 +2510,7 @@ func (cc *clientConn) handleCommonConnectionReset(ctx context.Context) error {
connectionInfo := cc.connectInfo()
cc.ctx.GetSessionVars().ConnectionInfo = connectionInfo
cc.extensions.OnConnectionEvent(extension.ConnReset, connectionInfo)
cc.onExtensionConnEvent(extension.ConnReset, nil)
err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent != nil {

View File

@ -1663,7 +1663,10 @@ func TestExtensionChangeUser(t *testing.T) {
outBuffer.Reset()
}
expectedConnInfo := extension.ConnEventInfo(*cc.connectInfo())
expectedConnInfo := extension.ConnEventInfo{
ConnectionInfo: cc.connectInfo(),
ActiveRoles: []*auth.RoleIdentity{},
}
expectedConnInfo.User = "user1"
expectedConnInfo.DB = "db1"
@ -1679,7 +1682,9 @@ func TestExtensionChangeUser(t *testing.T) {
})
require.True(t, logged)
require.Equal(t, extension.ConnReset, logTp)
require.Equal(t, expectedConnInfo, *logInfo)
require.Equal(t, expectedConnInfo.ActiveRoles, logInfo.ActiveRoles)
require.Equal(t, expectedConnInfo.Error, logInfo.Error)
require.Equal(t, *(expectedConnInfo.ConnectionInfo), *(logInfo.ConnectionInfo))
logged = false
logTp = 0
@ -1697,7 +1702,9 @@ func TestExtensionChangeUser(t *testing.T) {
})
require.True(t, logged)
require.Equal(t, extension.ConnReset, logTp)
require.Equal(t, expectedConnInfo, *logInfo)
require.Equal(t, expectedConnInfo.ActiveRoles, logInfo.ActiveRoles)
require.Equal(t, expectedConnInfo.Error, logInfo.Error)
require.Equal(t, *(expectedConnInfo.ConnectionInfo), *(logInfo.ConnectionInfo))
logged = false
logTp = 0
@ -1710,5 +1717,7 @@ func TestExtensionChangeUser(t *testing.T) {
})
require.True(t, logged)
require.Equal(t, extension.ConnReset, logTp)
require.Equal(t, expectedConnInfo, *logInfo)
require.Equal(t, expectedConnInfo.ActiveRoles, logInfo.ActiveRoles)
require.Equal(t, expectedConnInfo.Error, logInfo.Error)
require.Equal(t, *(expectedConnInfo.ConnectionInfo), *(logInfo.ConnectionInfo))
}

View File

@ -55,6 +55,7 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/auth"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/planner/core"
@ -510,7 +511,6 @@ func (s *Server) onConn(conn *clientConn) {
terror.Log(conn.Close())
return
}
connectionInfo := conn.connectInfo()
extensions, err := extension.GetExtensions()
if err != nil {
@ -522,18 +522,17 @@ func (s *Server) onConn(conn *clientConn) {
if sessExtensions := extensions.NewSessionExtensions(); sessExtensions != nil {
conn.extensions = sessExtensions
sessExtensions.OnConnectionEvent(extension.ConnConnected, connectionInfo)
conn.onExtensionConnEvent(extension.ConnConnected, nil)
defer func() {
sessExtensions.OnConnectionEvent(extension.ConnDisconnected, connectionInfo)
conn.onExtensionConnEvent(extension.ConnDisconnected, nil)
}()
}
ctx := logutil.WithConnID(context.Background(), conn.connectionID)
if err := conn.handshake(ctx); err != nil {
connectionInfo = conn.connectInfo()
conn.extensions.OnConnectionEvent(extension.ConnHandshakeRejected, connectionInfo)
conn.onExtensionConnEvent(extension.ConnHandshakeRejected, err)
if plugin.IsEnable(plugin.Audit) && conn.getCtx() != nil {
conn.getCtx().GetSessionVars().ConnectionInfo = connectionInfo
conn.getCtx().GetSessionVars().ConnectionInfo = conn.connectInfo()
err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent != nil {
@ -578,9 +577,8 @@ func (s *Server) onConn(conn *clientConn) {
metrics.ConnGauge.Set(float64(connections))
sessionVars := conn.ctx.GetSessionVars()
connectionInfo = conn.connectInfo()
sessionVars.ConnectionInfo = connectionInfo
conn.extensions.OnConnectionEvent(extension.ConnHandshakeAccepted, connectionInfo)
sessionVars.ConnectionInfo = conn.connectInfo()
conn.onExtensionConnEvent(extension.ConnHandshakeAccepted, nil)
err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent != nil {
@ -638,6 +636,32 @@ func (cc *clientConn) connectInfo() *variable.ConnectionInfo {
return connInfo
}
func (cc *clientConn) onExtensionConnEvent(tp extension.ConnEventTp, err error) {
if cc.extensions == nil {
return
}
var connInfo *variable.ConnectionInfo
var activeRoles []*auth.RoleIdentity
if ctx := cc.getCtx(); ctx != nil {
sessVars := ctx.GetSessionVars()
connInfo = sessVars.ConnectionInfo
activeRoles = sessVars.ActiveRoles
}
if connInfo == nil {
connInfo = cc.connectInfo()
}
info := &extension.ConnEventInfo{
ConnectionInfo: connInfo,
ActiveRoles: activeRoles,
Error: err,
}
cc.extensions.OnConnectionEvent(tp, info)
}
func (s *Server) checkConnectionCount() error {
// When the value of Instance.MaxConnections is 0, the number of connections is unlimited.
if int(s.cfg.Instance.MaxConnections) == 0 {

View File

@ -47,8 +47,10 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/auth"
tmysql "github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/sessiontxn"
"github.com/pingcap/tidb/store/mockstore"
"github.com/pingcap/tidb/store/mockstore/unistore"
@ -2825,28 +2827,50 @@ func TestExtensionConnEvent(t *testing.T) {
_ = conn.Close()
}()
var conn1, conn2 extension.ConnEventInfo
var expectedConn2 variable.ConnectionInfo
logs.check(func() {
require.Equal(t, []extension.ConnEventTp{
extension.ConnConnected,
extension.ConnHandshakeAccepted,
}, logs.types)
conn1 = logs.infos[0]
conn2 = conn1
conn2.User = "root"
conn2.DB = "test"
conn1 := logs.infos[0]
require.Equal(t, "127.0.0.1", conn1.ClientIP)
require.Equal(t, "127.0.0.1", conn1.ServerIP)
require.Empty(t, conn1.User)
require.Empty(t, conn1.DB)
require.Equal(t, conn2, logs.infos[1])
require.Equal(t, int(ts.port), conn1.ServerPort)
require.NotEqual(t, conn1.ServerPort, conn1.ClientPort)
require.NotEmpty(t, conn1.ConnectionID)
require.Nil(t, conn1.ActiveRoles)
require.NoError(t, conn1.Error)
expectedConn2 = *(conn1.ConnectionInfo)
expectedConn2.User = "root"
expectedConn2.DB = "test"
require.Equal(t, []*auth.RoleIdentity{}, logs.infos[1].ActiveRoles)
require.Nil(t, logs.infos[1].Error)
require.Equal(t, expectedConn2, *(logs.infos[1].ConnectionInfo))
})
_, err = conn.ExecContext(context.TODO(), "create role r1@'%'")
require.NoError(t, err)
_, err = conn.ExecContext(context.TODO(), "grant r1 TO root")
require.NoError(t, err)
_, err = conn.ExecContext(context.TODO(), "set role all")
require.NoError(t, err)
require.NoError(t, conn.Close())
require.NoError(t, db.Close())
require.NoError(t, logs.waitConnDisconnected())
logs.check(func() {
require.Equal(t, conn2, logs.infos[2])
require.Equal(t, 3, len(logs.infos))
require.Equal(t, 1, len(logs.infos[2].ActiveRoles))
require.Equal(t, auth.RoleIdentity{
Username: "r1",
Hostname: "%",
}, *logs.infos[2].ActiveRoles[0])
require.Nil(t, logs.infos[2].Error)
require.Equal(t, expectedConn2, *(logs.infos[2].ConnectionInfo))
})
// test for login failed
@ -2871,16 +2895,22 @@ func TestExtensionConnEvent(t *testing.T) {
extension.ConnHandshakeRejected,
extension.ConnDisconnected,
}, logs.types)
conn1 = logs.infos[0]
conn2 = conn1
conn2.User = "noexist"
conn2.DB = "test"
conn1 := logs.infos[0]
require.Equal(t, "127.0.0.1", conn1.ClientIP)
require.Equal(t, "127.0.0.1", conn1.ServerIP)
require.Empty(t, conn1.User)
require.Empty(t, conn1.DB)
require.Equal(t, conn2, logs.infos[1])
require.Equal(t, conn2, logs.infos[2])
require.Equal(t, int(ts.port), conn1.ServerPort)
require.NotEqual(t, conn1.ServerPort, conn1.ClientPort)
require.NotEmpty(t, conn1.ConnectionID)
require.Nil(t, conn1.ActiveRoles)
require.NoError(t, conn1.Error)
expectedConn2 = *(conn1.ConnectionInfo)
expectedConn2.User = "noexist"
expectedConn2.DB = "test"
require.Equal(t, []*auth.RoleIdentity{}, logs.infos[1].ActiveRoles)
require.EqualError(t, logs.infos[1].Error, "[server:1045]Access denied for user 'noexist'@'127.0.0.1' (using password: NO)")
require.Equal(t, expectedConn2, *(logs.infos[1].ConnectionInfo))
})
}