From ecdc0f74edceff881d01f55fc2bb14df7baa2e7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=B6=85?= Date: Tue, 1 Nov 2022 13:33:59 +0800 Subject: [PATCH] extension: add error and active roles info to `extension.ConnEventInfo` (#38752) close pingcap/tidb#38493 --- extension/BUILD.bazel | 1 + extension/extensionimpl/bootstrap.go | 2 +- extension/extensions.go | 2 +- extension/session.go | 16 +++++--- server/conn.go | 3 +- server/conn_test.go | 17 ++++++-- server/server.go | 42 ++++++++++++++----- server/tidb_test.go | 60 +++++++++++++++++++++------- 8 files changed, 106 insertions(+), 37 deletions(-) diff --git a/extension/BUILD.bazel b/extension/BUILD.bazel index 0ec88aff5a..a21f6ea2f4 100644 --- a/extension/BUILD.bazel +++ b/extension/BUILD.bazel @@ -13,6 +13,7 @@ go_library( importpath = "github.com/pingcap/tidb/extension", visibility = ["//visibility:public"], deps = [ + "//parser/auth", "//sessionctx/variable", "//types", "//util/chunk", diff --git a/extension/extensionimpl/bootstrap.go b/extension/extensionimpl/bootstrap.go index 1332bc6f9e..8b6b154f98 100644 --- a/extension/extensionimpl/bootstrap.go +++ b/extension/extensionimpl/bootstrap.go @@ -63,7 +63,7 @@ func (c *bootstrapContext) SessionPool() extension.SessionPool { return c.sessionPool } -// Bootstrap bootstrap all extensions +// Bootstrap bootstraps all extensions func Bootstrap(ctx context.Context, do *domain.Domain) error { extensions, err := extension.GetExtensions() if err != nil { diff --git a/extension/extensions.go b/extension/extensions.go index e977d20dc4..68bcffd585 100644 --- a/extension/extensions.go +++ b/extension/extensions.go @@ -29,7 +29,7 @@ func (es *Extensions) Manifests() []*Manifest { return manifests } -// Bootstrap bootstrap all extensions +// Bootstrap bootstraps all extensions func (es *Extensions) Bootstrap(ctx BootstrapContext) error { if es == nil { return nil diff --git a/extension/session.go b/extension/session.go index f0217f6da9..65fddb7531 100644 --- a/extension/session.go +++ b/extension/session.go @@ -14,10 +14,17 @@ package extension -import "github.com/pingcap/tidb/sessionctx/variable" +import ( + "github.com/pingcap/tidb/parser/auth" + "github.com/pingcap/tidb/sessionctx/variable" +) // ConnEventInfo is the connection info for the event -type ConnEventInfo variable.ConnectionInfo +type ConnEventInfo struct { + *variable.ConnectionInfo + ActiveRoles []*auth.RoleIdentity + Error error +} // ConnEventTp is the type of the connection event type ConnEventTp uint8 @@ -60,13 +67,12 @@ type SessionExtensions struct { } // OnConnectionEvent will be called when a connection event happens -func (es *SessionExtensions) OnConnectionEvent(tp ConnEventTp, info *variable.ConnectionInfo) { +func (es *SessionExtensions) OnConnectionEvent(tp ConnEventTp, event *ConnEventInfo) { if es == nil { return } - eventInfo := ConnEventInfo(*info) for _, fn := range es.connectionEventFuncs { - fn(tp, &eventInfo) + fn(tp, event) } } diff --git a/server/conn.go b/server/conn.go index c2e3f1db9b..33d07655e5 100644 --- a/server/conn.go +++ b/server/conn.go @@ -2510,8 +2510,7 @@ func (cc *clientConn) handleCommonConnectionReset(ctx context.Context) error { connectionInfo := cc.connectInfo() cc.ctx.GetSessionVars().ConnectionInfo = connectionInfo - cc.extensions.OnConnectionEvent(extension.ConnReset, connectionInfo) - + cc.onExtensionConnEvent(extension.ConnReset, nil) err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { authPlugin := plugin.DeclareAuditManifest(p.Manifest) if authPlugin.OnConnectionEvent != nil { diff --git a/server/conn_test.go b/server/conn_test.go index 0f9acc6f18..28e023ea8c 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -1663,7 +1663,10 @@ func TestExtensionChangeUser(t *testing.T) { outBuffer.Reset() } - expectedConnInfo := extension.ConnEventInfo(*cc.connectInfo()) + expectedConnInfo := extension.ConnEventInfo{ + ConnectionInfo: cc.connectInfo(), + ActiveRoles: []*auth.RoleIdentity{}, + } expectedConnInfo.User = "user1" expectedConnInfo.DB = "db1" @@ -1679,7 +1682,9 @@ func TestExtensionChangeUser(t *testing.T) { }) require.True(t, logged) require.Equal(t, extension.ConnReset, logTp) - require.Equal(t, expectedConnInfo, *logInfo) + require.Equal(t, expectedConnInfo.ActiveRoles, logInfo.ActiveRoles) + require.Equal(t, expectedConnInfo.Error, logInfo.Error) + require.Equal(t, *(expectedConnInfo.ConnectionInfo), *(logInfo.ConnectionInfo)) logged = false logTp = 0 @@ -1697,7 +1702,9 @@ func TestExtensionChangeUser(t *testing.T) { }) require.True(t, logged) require.Equal(t, extension.ConnReset, logTp) - require.Equal(t, expectedConnInfo, *logInfo) + require.Equal(t, expectedConnInfo.ActiveRoles, logInfo.ActiveRoles) + require.Equal(t, expectedConnInfo.Error, logInfo.Error) + require.Equal(t, *(expectedConnInfo.ConnectionInfo), *(logInfo.ConnectionInfo)) logged = false logTp = 0 @@ -1710,5 +1717,7 @@ func TestExtensionChangeUser(t *testing.T) { }) require.True(t, logged) require.Equal(t, extension.ConnReset, logTp) - require.Equal(t, expectedConnInfo, *logInfo) + require.Equal(t, expectedConnInfo.ActiveRoles, logInfo.ActiveRoles) + require.Equal(t, expectedConnInfo.Error, logInfo.Error) + require.Equal(t, *(expectedConnInfo.ConnectionInfo), *(logInfo.ConnectionInfo)) } diff --git a/server/server.go b/server/server.go index 9ed4b96391..5efef687ec 100644 --- a/server/server.go +++ b/server/server.go @@ -55,6 +55,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/auth" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" "github.com/pingcap/tidb/planner/core" @@ -510,7 +511,6 @@ func (s *Server) onConn(conn *clientConn) { terror.Log(conn.Close()) return } - connectionInfo := conn.connectInfo() extensions, err := extension.GetExtensions() if err != nil { @@ -522,18 +522,17 @@ func (s *Server) onConn(conn *clientConn) { if sessExtensions := extensions.NewSessionExtensions(); sessExtensions != nil { conn.extensions = sessExtensions - sessExtensions.OnConnectionEvent(extension.ConnConnected, connectionInfo) + conn.onExtensionConnEvent(extension.ConnConnected, nil) defer func() { - sessExtensions.OnConnectionEvent(extension.ConnDisconnected, connectionInfo) + conn.onExtensionConnEvent(extension.ConnDisconnected, nil) }() } ctx := logutil.WithConnID(context.Background(), conn.connectionID) if err := conn.handshake(ctx); err != nil { - connectionInfo = conn.connectInfo() - conn.extensions.OnConnectionEvent(extension.ConnHandshakeRejected, connectionInfo) + conn.onExtensionConnEvent(extension.ConnHandshakeRejected, err) if plugin.IsEnable(plugin.Audit) && conn.getCtx() != nil { - conn.getCtx().GetSessionVars().ConnectionInfo = connectionInfo + conn.getCtx().GetSessionVars().ConnectionInfo = conn.connectInfo() err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { authPlugin := plugin.DeclareAuditManifest(p.Manifest) if authPlugin.OnConnectionEvent != nil { @@ -578,9 +577,8 @@ func (s *Server) onConn(conn *clientConn) { metrics.ConnGauge.Set(float64(connections)) sessionVars := conn.ctx.GetSessionVars() - connectionInfo = conn.connectInfo() - sessionVars.ConnectionInfo = connectionInfo - conn.extensions.OnConnectionEvent(extension.ConnHandshakeAccepted, connectionInfo) + sessionVars.ConnectionInfo = conn.connectInfo() + conn.onExtensionConnEvent(extension.ConnHandshakeAccepted, nil) err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { authPlugin := plugin.DeclareAuditManifest(p.Manifest) if authPlugin.OnConnectionEvent != nil { @@ -638,6 +636,32 @@ func (cc *clientConn) connectInfo() *variable.ConnectionInfo { return connInfo } +func (cc *clientConn) onExtensionConnEvent(tp extension.ConnEventTp, err error) { + if cc.extensions == nil { + return + } + + var connInfo *variable.ConnectionInfo + var activeRoles []*auth.RoleIdentity + if ctx := cc.getCtx(); ctx != nil { + sessVars := ctx.GetSessionVars() + connInfo = sessVars.ConnectionInfo + activeRoles = sessVars.ActiveRoles + } + + if connInfo == nil { + connInfo = cc.connectInfo() + } + + info := &extension.ConnEventInfo{ + ConnectionInfo: connInfo, + ActiveRoles: activeRoles, + Error: err, + } + + cc.extensions.OnConnectionEvent(tp, info) +} + func (s *Server) checkConnectionCount() error { // When the value of Instance.MaxConnections is 0, the number of connections is unlimited. if int(s.cfg.Instance.MaxConnections) == 0 { diff --git a/server/tidb_test.go b/server/tidb_test.go index 293406fe87..633dab3cef 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -47,8 +47,10 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/auth" tmysql "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/sessiontxn" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/store/mockstore/unistore" @@ -2825,28 +2827,50 @@ func TestExtensionConnEvent(t *testing.T) { _ = conn.Close() }() - var conn1, conn2 extension.ConnEventInfo + var expectedConn2 variable.ConnectionInfo logs.check(func() { require.Equal(t, []extension.ConnEventTp{ extension.ConnConnected, extension.ConnHandshakeAccepted, }, logs.types) - conn1 = logs.infos[0] - conn2 = conn1 - conn2.User = "root" - conn2.DB = "test" - + conn1 := logs.infos[0] require.Equal(t, "127.0.0.1", conn1.ClientIP) require.Equal(t, "127.0.0.1", conn1.ServerIP) require.Empty(t, conn1.User) require.Empty(t, conn1.DB) - require.Equal(t, conn2, logs.infos[1]) + require.Equal(t, int(ts.port), conn1.ServerPort) + require.NotEqual(t, conn1.ServerPort, conn1.ClientPort) + require.NotEmpty(t, conn1.ConnectionID) + require.Nil(t, conn1.ActiveRoles) + require.NoError(t, conn1.Error) + + expectedConn2 = *(conn1.ConnectionInfo) + expectedConn2.User = "root" + expectedConn2.DB = "test" + require.Equal(t, []*auth.RoleIdentity{}, logs.infos[1].ActiveRoles) + require.Nil(t, logs.infos[1].Error) + require.Equal(t, expectedConn2, *(logs.infos[1].ConnectionInfo)) }) + + _, err = conn.ExecContext(context.TODO(), "create role r1@'%'") + require.NoError(t, err) + _, err = conn.ExecContext(context.TODO(), "grant r1 TO root") + require.NoError(t, err) + _, err = conn.ExecContext(context.TODO(), "set role all") + require.NoError(t, err) + require.NoError(t, conn.Close()) require.NoError(t, db.Close()) require.NoError(t, logs.waitConnDisconnected()) logs.check(func() { - require.Equal(t, conn2, logs.infos[2]) + require.Equal(t, 3, len(logs.infos)) + require.Equal(t, 1, len(logs.infos[2].ActiveRoles)) + require.Equal(t, auth.RoleIdentity{ + Username: "r1", + Hostname: "%", + }, *logs.infos[2].ActiveRoles[0]) + require.Nil(t, logs.infos[2].Error) + require.Equal(t, expectedConn2, *(logs.infos[2].ConnectionInfo)) }) // test for login failed @@ -2871,16 +2895,22 @@ func TestExtensionConnEvent(t *testing.T) { extension.ConnHandshakeRejected, extension.ConnDisconnected, }, logs.types) - conn1 = logs.infos[0] - conn2 = conn1 - conn2.User = "noexist" - conn2.DB = "test" - + conn1 := logs.infos[0] require.Equal(t, "127.0.0.1", conn1.ClientIP) require.Equal(t, "127.0.0.1", conn1.ServerIP) require.Empty(t, conn1.User) require.Empty(t, conn1.DB) - require.Equal(t, conn2, logs.infos[1]) - require.Equal(t, conn2, logs.infos[2]) + require.Equal(t, int(ts.port), conn1.ServerPort) + require.NotEqual(t, conn1.ServerPort, conn1.ClientPort) + require.NotEmpty(t, conn1.ConnectionID) + require.Nil(t, conn1.ActiveRoles) + require.NoError(t, conn1.Error) + + expectedConn2 = *(conn1.ConnectionInfo) + expectedConn2.User = "noexist" + expectedConn2.DB = "test" + require.Equal(t, []*auth.RoleIdentity{}, logs.infos[1].ActiveRoles) + require.EqualError(t, logs.infos[1].Error, "[server:1045]Access denied for user 'noexist'@'127.0.0.1' (using password: NO)") + require.Equal(t, expectedConn2, *(logs.infos[1].ConnectionInfo)) }) }