extension: add error and active roles info to extension.ConnEventInfo (#38752)
close pingcap/tidb#38493
This commit is contained in:
@ -13,6 +13,7 @@ go_library(
|
||||
importpath = "github.com/pingcap/tidb/extension",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//parser/auth",
|
||||
"//sessionctx/variable",
|
||||
"//types",
|
||||
"//util/chunk",
|
||||
|
||||
@ -63,7 +63,7 @@ func (c *bootstrapContext) SessionPool() extension.SessionPool {
|
||||
return c.sessionPool
|
||||
}
|
||||
|
||||
// Bootstrap bootstrap all extensions
|
||||
// Bootstrap bootstraps all extensions
|
||||
func Bootstrap(ctx context.Context, do *domain.Domain) error {
|
||||
extensions, err := extension.GetExtensions()
|
||||
if err != nil {
|
||||
|
||||
@ -29,7 +29,7 @@ func (es *Extensions) Manifests() []*Manifest {
|
||||
return manifests
|
||||
}
|
||||
|
||||
// Bootstrap bootstrap all extensions
|
||||
// Bootstrap bootstraps all extensions
|
||||
func (es *Extensions) Bootstrap(ctx BootstrapContext) error {
|
||||
if es == nil {
|
||||
return nil
|
||||
|
||||
@ -14,10 +14,17 @@
|
||||
|
||||
package extension
|
||||
|
||||
import "github.com/pingcap/tidb/sessionctx/variable"
|
||||
import (
|
||||
"github.com/pingcap/tidb/parser/auth"
|
||||
"github.com/pingcap/tidb/sessionctx/variable"
|
||||
)
|
||||
|
||||
// ConnEventInfo is the connection info for the event
|
||||
type ConnEventInfo variable.ConnectionInfo
|
||||
type ConnEventInfo struct {
|
||||
*variable.ConnectionInfo
|
||||
ActiveRoles []*auth.RoleIdentity
|
||||
Error error
|
||||
}
|
||||
|
||||
// ConnEventTp is the type of the connection event
|
||||
type ConnEventTp uint8
|
||||
@ -60,13 +67,12 @@ type SessionExtensions struct {
|
||||
}
|
||||
|
||||
// OnConnectionEvent will be called when a connection event happens
|
||||
func (es *SessionExtensions) OnConnectionEvent(tp ConnEventTp, info *variable.ConnectionInfo) {
|
||||
func (es *SessionExtensions) OnConnectionEvent(tp ConnEventTp, event *ConnEventInfo) {
|
||||
if es == nil {
|
||||
return
|
||||
}
|
||||
|
||||
eventInfo := ConnEventInfo(*info)
|
||||
for _, fn := range es.connectionEventFuncs {
|
||||
fn(tp, &eventInfo)
|
||||
fn(tp, event)
|
||||
}
|
||||
}
|
||||
|
||||
@ -2510,8 +2510,7 @@ func (cc *clientConn) handleCommonConnectionReset(ctx context.Context) error {
|
||||
connectionInfo := cc.connectInfo()
|
||||
cc.ctx.GetSessionVars().ConnectionInfo = connectionInfo
|
||||
|
||||
cc.extensions.OnConnectionEvent(extension.ConnReset, connectionInfo)
|
||||
|
||||
cc.onExtensionConnEvent(extension.ConnReset, nil)
|
||||
err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
|
||||
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
|
||||
if authPlugin.OnConnectionEvent != nil {
|
||||
|
||||
@ -1663,7 +1663,10 @@ func TestExtensionChangeUser(t *testing.T) {
|
||||
outBuffer.Reset()
|
||||
}
|
||||
|
||||
expectedConnInfo := extension.ConnEventInfo(*cc.connectInfo())
|
||||
expectedConnInfo := extension.ConnEventInfo{
|
||||
ConnectionInfo: cc.connectInfo(),
|
||||
ActiveRoles: []*auth.RoleIdentity{},
|
||||
}
|
||||
expectedConnInfo.User = "user1"
|
||||
expectedConnInfo.DB = "db1"
|
||||
|
||||
@ -1679,7 +1682,9 @@ func TestExtensionChangeUser(t *testing.T) {
|
||||
})
|
||||
require.True(t, logged)
|
||||
require.Equal(t, extension.ConnReset, logTp)
|
||||
require.Equal(t, expectedConnInfo, *logInfo)
|
||||
require.Equal(t, expectedConnInfo.ActiveRoles, logInfo.ActiveRoles)
|
||||
require.Equal(t, expectedConnInfo.Error, logInfo.Error)
|
||||
require.Equal(t, *(expectedConnInfo.ConnectionInfo), *(logInfo.ConnectionInfo))
|
||||
|
||||
logged = false
|
||||
logTp = 0
|
||||
@ -1697,7 +1702,9 @@ func TestExtensionChangeUser(t *testing.T) {
|
||||
})
|
||||
require.True(t, logged)
|
||||
require.Equal(t, extension.ConnReset, logTp)
|
||||
require.Equal(t, expectedConnInfo, *logInfo)
|
||||
require.Equal(t, expectedConnInfo.ActiveRoles, logInfo.ActiveRoles)
|
||||
require.Equal(t, expectedConnInfo.Error, logInfo.Error)
|
||||
require.Equal(t, *(expectedConnInfo.ConnectionInfo), *(logInfo.ConnectionInfo))
|
||||
|
||||
logged = false
|
||||
logTp = 0
|
||||
@ -1710,5 +1717,7 @@ func TestExtensionChangeUser(t *testing.T) {
|
||||
})
|
||||
require.True(t, logged)
|
||||
require.Equal(t, extension.ConnReset, logTp)
|
||||
require.Equal(t, expectedConnInfo, *logInfo)
|
||||
require.Equal(t, expectedConnInfo.ActiveRoles, logInfo.ActiveRoles)
|
||||
require.Equal(t, expectedConnInfo.Error, logInfo.Error)
|
||||
require.Equal(t, *(expectedConnInfo.ConnectionInfo), *(logInfo.ConnectionInfo))
|
||||
}
|
||||
|
||||
@ -55,6 +55,7 @@ import (
|
||||
"github.com/pingcap/tidb/kv"
|
||||
"github.com/pingcap/tidb/metrics"
|
||||
"github.com/pingcap/tidb/parser/ast"
|
||||
"github.com/pingcap/tidb/parser/auth"
|
||||
"github.com/pingcap/tidb/parser/mysql"
|
||||
"github.com/pingcap/tidb/parser/terror"
|
||||
"github.com/pingcap/tidb/planner/core"
|
||||
@ -510,7 +511,6 @@ func (s *Server) onConn(conn *clientConn) {
|
||||
terror.Log(conn.Close())
|
||||
return
|
||||
}
|
||||
connectionInfo := conn.connectInfo()
|
||||
|
||||
extensions, err := extension.GetExtensions()
|
||||
if err != nil {
|
||||
@ -522,18 +522,17 @@ func (s *Server) onConn(conn *clientConn) {
|
||||
|
||||
if sessExtensions := extensions.NewSessionExtensions(); sessExtensions != nil {
|
||||
conn.extensions = sessExtensions
|
||||
sessExtensions.OnConnectionEvent(extension.ConnConnected, connectionInfo)
|
||||
conn.onExtensionConnEvent(extension.ConnConnected, nil)
|
||||
defer func() {
|
||||
sessExtensions.OnConnectionEvent(extension.ConnDisconnected, connectionInfo)
|
||||
conn.onExtensionConnEvent(extension.ConnDisconnected, nil)
|
||||
}()
|
||||
}
|
||||
|
||||
ctx := logutil.WithConnID(context.Background(), conn.connectionID)
|
||||
if err := conn.handshake(ctx); err != nil {
|
||||
connectionInfo = conn.connectInfo()
|
||||
conn.extensions.OnConnectionEvent(extension.ConnHandshakeRejected, connectionInfo)
|
||||
conn.onExtensionConnEvent(extension.ConnHandshakeRejected, err)
|
||||
if plugin.IsEnable(plugin.Audit) && conn.getCtx() != nil {
|
||||
conn.getCtx().GetSessionVars().ConnectionInfo = connectionInfo
|
||||
conn.getCtx().GetSessionVars().ConnectionInfo = conn.connectInfo()
|
||||
err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
|
||||
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
|
||||
if authPlugin.OnConnectionEvent != nil {
|
||||
@ -578,9 +577,8 @@ func (s *Server) onConn(conn *clientConn) {
|
||||
metrics.ConnGauge.Set(float64(connections))
|
||||
|
||||
sessionVars := conn.ctx.GetSessionVars()
|
||||
connectionInfo = conn.connectInfo()
|
||||
sessionVars.ConnectionInfo = connectionInfo
|
||||
conn.extensions.OnConnectionEvent(extension.ConnHandshakeAccepted, connectionInfo)
|
||||
sessionVars.ConnectionInfo = conn.connectInfo()
|
||||
conn.onExtensionConnEvent(extension.ConnHandshakeAccepted, nil)
|
||||
err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
|
||||
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
|
||||
if authPlugin.OnConnectionEvent != nil {
|
||||
@ -638,6 +636,32 @@ func (cc *clientConn) connectInfo() *variable.ConnectionInfo {
|
||||
return connInfo
|
||||
}
|
||||
|
||||
func (cc *clientConn) onExtensionConnEvent(tp extension.ConnEventTp, err error) {
|
||||
if cc.extensions == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var connInfo *variable.ConnectionInfo
|
||||
var activeRoles []*auth.RoleIdentity
|
||||
if ctx := cc.getCtx(); ctx != nil {
|
||||
sessVars := ctx.GetSessionVars()
|
||||
connInfo = sessVars.ConnectionInfo
|
||||
activeRoles = sessVars.ActiveRoles
|
||||
}
|
||||
|
||||
if connInfo == nil {
|
||||
connInfo = cc.connectInfo()
|
||||
}
|
||||
|
||||
info := &extension.ConnEventInfo{
|
||||
ConnectionInfo: connInfo,
|
||||
ActiveRoles: activeRoles,
|
||||
Error: err,
|
||||
}
|
||||
|
||||
cc.extensions.OnConnectionEvent(tp, info)
|
||||
}
|
||||
|
||||
func (s *Server) checkConnectionCount() error {
|
||||
// When the value of Instance.MaxConnections is 0, the number of connections is unlimited.
|
||||
if int(s.cfg.Instance.MaxConnections) == 0 {
|
||||
|
||||
@ -47,8 +47,10 @@ import (
|
||||
"github.com/pingcap/tidb/kv"
|
||||
"github.com/pingcap/tidb/parser"
|
||||
"github.com/pingcap/tidb/parser/ast"
|
||||
"github.com/pingcap/tidb/parser/auth"
|
||||
tmysql "github.com/pingcap/tidb/parser/mysql"
|
||||
"github.com/pingcap/tidb/session"
|
||||
"github.com/pingcap/tidb/sessionctx/variable"
|
||||
"github.com/pingcap/tidb/sessiontxn"
|
||||
"github.com/pingcap/tidb/store/mockstore"
|
||||
"github.com/pingcap/tidb/store/mockstore/unistore"
|
||||
@ -2825,28 +2827,50 @@ func TestExtensionConnEvent(t *testing.T) {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
var conn1, conn2 extension.ConnEventInfo
|
||||
var expectedConn2 variable.ConnectionInfo
|
||||
logs.check(func() {
|
||||
require.Equal(t, []extension.ConnEventTp{
|
||||
extension.ConnConnected,
|
||||
extension.ConnHandshakeAccepted,
|
||||
}, logs.types)
|
||||
conn1 = logs.infos[0]
|
||||
conn2 = conn1
|
||||
conn2.User = "root"
|
||||
conn2.DB = "test"
|
||||
|
||||
conn1 := logs.infos[0]
|
||||
require.Equal(t, "127.0.0.1", conn1.ClientIP)
|
||||
require.Equal(t, "127.0.0.1", conn1.ServerIP)
|
||||
require.Empty(t, conn1.User)
|
||||
require.Empty(t, conn1.DB)
|
||||
require.Equal(t, conn2, logs.infos[1])
|
||||
require.Equal(t, int(ts.port), conn1.ServerPort)
|
||||
require.NotEqual(t, conn1.ServerPort, conn1.ClientPort)
|
||||
require.NotEmpty(t, conn1.ConnectionID)
|
||||
require.Nil(t, conn1.ActiveRoles)
|
||||
require.NoError(t, conn1.Error)
|
||||
|
||||
expectedConn2 = *(conn1.ConnectionInfo)
|
||||
expectedConn2.User = "root"
|
||||
expectedConn2.DB = "test"
|
||||
require.Equal(t, []*auth.RoleIdentity{}, logs.infos[1].ActiveRoles)
|
||||
require.Nil(t, logs.infos[1].Error)
|
||||
require.Equal(t, expectedConn2, *(logs.infos[1].ConnectionInfo))
|
||||
})
|
||||
|
||||
_, err = conn.ExecContext(context.TODO(), "create role r1@'%'")
|
||||
require.NoError(t, err)
|
||||
_, err = conn.ExecContext(context.TODO(), "grant r1 TO root")
|
||||
require.NoError(t, err)
|
||||
_, err = conn.ExecContext(context.TODO(), "set role all")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.Close())
|
||||
require.NoError(t, db.Close())
|
||||
require.NoError(t, logs.waitConnDisconnected())
|
||||
logs.check(func() {
|
||||
require.Equal(t, conn2, logs.infos[2])
|
||||
require.Equal(t, 3, len(logs.infos))
|
||||
require.Equal(t, 1, len(logs.infos[2].ActiveRoles))
|
||||
require.Equal(t, auth.RoleIdentity{
|
||||
Username: "r1",
|
||||
Hostname: "%",
|
||||
}, *logs.infos[2].ActiveRoles[0])
|
||||
require.Nil(t, logs.infos[2].Error)
|
||||
require.Equal(t, expectedConn2, *(logs.infos[2].ConnectionInfo))
|
||||
})
|
||||
|
||||
// test for login failed
|
||||
@ -2871,16 +2895,22 @@ func TestExtensionConnEvent(t *testing.T) {
|
||||
extension.ConnHandshakeRejected,
|
||||
extension.ConnDisconnected,
|
||||
}, logs.types)
|
||||
conn1 = logs.infos[0]
|
||||
conn2 = conn1
|
||||
conn2.User = "noexist"
|
||||
conn2.DB = "test"
|
||||
|
||||
conn1 := logs.infos[0]
|
||||
require.Equal(t, "127.0.0.1", conn1.ClientIP)
|
||||
require.Equal(t, "127.0.0.1", conn1.ServerIP)
|
||||
require.Empty(t, conn1.User)
|
||||
require.Empty(t, conn1.DB)
|
||||
require.Equal(t, conn2, logs.infos[1])
|
||||
require.Equal(t, conn2, logs.infos[2])
|
||||
require.Equal(t, int(ts.port), conn1.ServerPort)
|
||||
require.NotEqual(t, conn1.ServerPort, conn1.ClientPort)
|
||||
require.NotEmpty(t, conn1.ConnectionID)
|
||||
require.Nil(t, conn1.ActiveRoles)
|
||||
require.NoError(t, conn1.Error)
|
||||
|
||||
expectedConn2 = *(conn1.ConnectionInfo)
|
||||
expectedConn2.User = "noexist"
|
||||
expectedConn2.DB = "test"
|
||||
require.Equal(t, []*auth.RoleIdentity{}, logs.infos[1].ActiveRoles)
|
||||
require.EqualError(t, logs.infos[1].Error, "[server:1045]Access denied for user 'noexist'@'127.0.0.1' (using password: NO)")
|
||||
require.Equal(t, expectedConn2, *(logs.infos[1].ConnectionInfo))
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user