170 lines
5.7 KiB
Go
170 lines
5.7 KiB
Go
// Copyright 2022 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package extension
|
|
|
|
import (
|
|
"github.com/pingcap/tidb/pkg/parser"
|
|
"github.com/pingcap/tidb/pkg/parser/ast"
|
|
"github.com/pingcap/tidb/pkg/parser/auth"
|
|
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
|
|
"github.com/pingcap/tidb/pkg/sessionctx/variable"
|
|
"github.com/pingcap/tidb/pkg/types"
|
|
)
|
|
|
|
// ConnEventInfo carries the information delivered to connection event
// listeners (see SessionHandler.OnConnectionEvent).
type ConnEventInfo struct {
	// ConnectionInfo is the embedded connection information of the session.
	*variable.ConnectionInfo
	// SessionAlias is the alias of the session.
	SessionAlias string
	// ActiveRoles are the active roles of the connection's user.
	ActiveRoles []*auth.RoleIdentity
	// Error is the error attached to the event, if any.
	// NOTE(review): presumably only set for failure events such as
	// ConnHandshakeRejected — confirm against the callers.
	Error error
}
|
|
|
|
// ConnEventTp is the type of a connection event delivered to
// SessionHandler.OnConnectionEvent.
//
// The values are assigned with iota, so their numeric values follow the
// declaration order below; inserting a new kind in the middle would renumber
// every constant after it.
type ConnEventTp uint8

const (
	// ConnConnected means the connection is established, but the handshake has
	// not happened yet.
	ConnConnected ConnEventTp = iota
	// ConnHandshakeAccepted means the connection is accepted after the handshake.
	ConnHandshakeAccepted
	// ConnHandshakeRejected means the connection is rejected after the handshake.
	ConnHandshakeRejected
	// ConnReset means the connection is reset.
	ConnReset
	// ConnDisconnected means the connection is disconnected.
	ConnDisconnected
)
|
|
|
|
// StmtEventTp is the type of a statement event delivered to
// SessionHandler.OnStmtEvent.
//
// The values are assigned with iota, so their numeric values follow the
// declaration order below.
type StmtEventTp uint8

const (
	// StmtError means the statement failed.
	StmtError StmtEventTp = iota
	// StmtSuccess means the statement was executed successfully.
	StmtSuccess
)
|
|
|
|
// StmtEventInfo is the information about a statement that is passed to
// SessionHandler.OnStmtEvent listeners.
type StmtEventInfo interface {
	// User returns the user of the session.
	User() *auth.UserIdentity
	// ActiveRoles returns the active roles of the user.
	ActiveRoles() []*auth.RoleIdentity
	// CurrentDB returns the current database.
	CurrentDB() string
	// ConnectionInfo returns the connection info of the current session.
	ConnectionInfo() *variable.ConnectionInfo
	// SessionAlias returns the session alias value set by the user.
	SessionAlias() string
	// StmtNode returns the parsed AST of the statement.
	// It returns nil when the statement failed to parse.
	StmtNode() ast.StmtNode
	// ExecuteStmtNode returns the `ast.ExecuteStmt` node when the current
	// statement is EXECUTE; otherwise it returns nil.
	ExecuteStmtNode() *ast.ExecuteStmt
	// ExecutePreparedStmt returns the prepared statement node for the EXECUTE
	// statement. It returns nil if the current statement is not EXECUTE or the
	// prepared statement is not found.
	ExecutePreparedStmt() ast.StmtNode
	// PreparedParams returns the parameters of the EXECUTE statement.
	PreparedParams() []types.Datum
	// OriginalText returns the text of the statement.
	// Note that for an EXECUTE statement, the text of the prepared statement is
	// returned instead.
	OriginalText() string
	// SQLDigest returns the normalized and redacted text of `OriginalText()`
	// together with its digest.
	SQLDigest() (normalized string, digest *parser.Digest)
	// AffectedRows returns the number of rows affected by the current statement.
	AffectedRows() uint64
	// RelatedTables returns the tables related to the current statement.
	// For statements that succeed in building a logical plan, the tables are
	// taken from the `visitinfo`; for statements that fail to build a logical
	// plan, they are collected by traversing the AST.
	RelatedTables() []stmtctx.TableEntry
	// GetError returns the error of the current statement when it failed.
	GetError() error
}
|
|
|
|
// SessionHandler is used to listen to session events. Either callback may be
// left nil; nil callbacks are skipped when the session extensions are built.
type SessionHandler struct {
	// OnConnectionEvent, if non-nil, is invoked by
	// SessionExtensions.OnConnectionEvent for connection events.
	OnConnectionEvent func(ConnEventTp, *ConnEventInfo)
	// OnStmtEvent, if non-nil, is invoked by SessionExtensions.OnStmtEvent for
	// statement events.
	OnStmtEvent func(StmtEventTp, StmtEventInfo)
}
|
|
|
|
func newSessionExtensions(es *Extensions) *SessionExtensions {
|
|
connExtensions := &SessionExtensions{}
|
|
for _, m := range es.Manifests() {
|
|
if m.sessionHandlerFactory != nil {
|
|
if handler := m.sessionHandlerFactory(); handler != nil {
|
|
if fn := handler.OnConnectionEvent; fn != nil {
|
|
connExtensions.connectionEventFuncs = append(connExtensions.connectionEventFuncs, fn)
|
|
}
|
|
if fn := handler.OnStmtEvent; fn != nil {
|
|
connExtensions.stmtEventFuncs = append(connExtensions.stmtEventFuncs, fn)
|
|
}
|
|
}
|
|
}
|
|
if m.authPlugins != nil {
|
|
connExtensions.authPlugins = make(map[string]*AuthPlugin)
|
|
for _, p := range m.authPlugins {
|
|
connExtensions.authPlugins[p.Name] = p
|
|
}
|
|
}
|
|
}
|
|
return connExtensions
|
|
}
|
|
|
|
// SessionExtensions holds the extension callbacks and auth plugins that apply
// to a session. A nil *SessionExtensions is valid: its methods behave as if no
// extension is registered.
type SessionExtensions struct {
	// connectionEventFuncs are the non-nil OnConnectionEvent callbacks, called in slice order.
	connectionEventFuncs []func(ConnEventTp, *ConnEventInfo)
	// stmtEventFuncs are the non-nil OnStmtEvent callbacks, called in slice order.
	stmtEventFuncs []func(StmtEventTp, StmtEventInfo)

	// authPlugins indexes the custom auth plugins by their Name.
	authPlugins map[string]*AuthPlugin
}
|
|
|
|
// OnConnectionEvent will be called when a connection event happens
|
|
func (es *SessionExtensions) OnConnectionEvent(tp ConnEventTp, event *ConnEventInfo) {
|
|
if es == nil {
|
|
return
|
|
}
|
|
|
|
for _, fn := range es.connectionEventFuncs {
|
|
fn(tp, event)
|
|
}
|
|
}
|
|
|
|
// HasStmtEventListeners returns a bool that indicates if any stmt event listener exists
|
|
func (es *SessionExtensions) HasStmtEventListeners() bool {
|
|
return es != nil && len(es.stmtEventFuncs) > 0
|
|
}
|
|
|
|
// OnStmtEvent will be called when a stmt event happens
|
|
func (es *SessionExtensions) OnStmtEvent(tp StmtEventTp, event StmtEventInfo) {
|
|
if es == nil {
|
|
return
|
|
}
|
|
|
|
for _, fn := range es.stmtEventFuncs {
|
|
fn(tp, event)
|
|
}
|
|
}
|
|
|
|
// GetAuthPlugin returns the required registered extension auth plugin and whether it exists.
|
|
func (es *SessionExtensions) GetAuthPlugin(name string) (*AuthPlugin, bool) {
|
|
if es == nil {
|
|
return nil, false
|
|
}
|
|
p, ok := es.authPlugins[name]
|
|
return p, ok
|
|
}
|