// Copyright 2013 The ql Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSES/QL-LICENSE file. // Copyright 2015 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // See the License for the specific language governing permissions and // limitations under the License. package tidb import ( "bytes" "encoding/json" "fmt" "strings" "sync" "sync/atomic" "time" "github.com/juju/errors" "github.com/ngaut/log" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/privilege/privileges" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/autocommit" "github.com/pingcap/tidb/sessionctx/db" "github.com/pingcap/tidb/sessionctx/forupdate" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/stmt" "github.com/pingcap/tidb/stmt/stmts" "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/types" ) // Session context type Session interface { Status() uint16 // Flag of current status, such as autocommit LastInsertID() uint64 // Last inserted auto_increment id AffectedRows() uint64 // Affected rows by lastest executed stmt Execute(sql string) ([]rset.Recordset, error) // Execute a sql statement String() string // For debug FinishTxn(rollback bool) error // For execute prepare statement in binary protocol PrepareStmt(sql string) (stmtID uint32, paramCount int, fields []*field.ResultField, err error) // Execute a prepared statement ExecutePreparedStmt(stmtID uint32, param ...interface{}) (rset.Recordset, error) DropPreparedStmt(stmtID uint32) error SetClientCapability(uint32) // Set client capability flags Close() error Retry() error Auth(user string, auth []byte, salt []byte) bool } var ( _ Session = (*session)(nil) sessionID int64 sessionMu sync.Mutex ) type stmtRecord struct { stmtID uint32 st stmt.Statement params []interface{} } type stmtHistory struct { history []*stmtRecord } func (h *stmtHistory) add(stmtID uint32, st stmt.Statement, params ...interface{}) { s := &stmtRecord{ stmtID: stmtID, st: st, params: append(([]interface{})(nil), params...), } h.history = append(h.history, s) } func (h *stmtHistory) reset() { if len(h.history) > 0 { h.history = h.history[:0] } } func (h *stmtHistory) clone() *stmtHistory { nh := *h nh.history = make([]*stmtRecord, len(h.history)) copy(nh.history, h.history) return &nh } type session struct { txn kv.Transaction // Current transaction args []interface{} // Statment execution args, this should be cleaned up after exec values map[fmt.Stringer]interface{} store kv.Storage sid int64 history stmtHistory initing bool // Running bootstrap using this session. } func (s *session) Status() uint16 { return variable.GetSessionVars(s).Status } func (s *session) LastInsertID() uint64 { return variable.GetSessionVars(s).LastInsertID } func (s *session) AffectedRows() uint64 { return variable.GetSessionVars(s).AffectedRows } func (s *session) resetHistory() { s.ClearValue(forupdate.ForUpdateKey) s.history.reset() } func (s *session) SetClientCapability(capability uint32) { variable.GetSessionVars(s).ClientCapability = capability } func (s *session) FinishTxn(rollback bool) error { // transaction has already been committed or rolled back if s.txn == nil { return nil } defer func() { s.txn = nil variable.GetSessionVars(s).SetStatusFlag(mysql.ServerStatusInTrans, false) }() if rollback { return s.txn.Rollback() } err := s.txn.Commit() if err != nil { log.Warnf("txn:%s, %v", s.txn, err) return errors.Trace(err) } s.resetHistory() return nil } func (s *session) String() string { // TODO: how to print binded context in values appropriately? data := map[string]interface{}{ "currDBName": db.GetCurrentSchema(s), "sid": s.sid, } if s.txn != nil { // if txn is committed or rolled back, txn is nil. data["txn"] = s.txn.String() } b, _ := json.MarshalIndent(data, "", " ") return string(b) } func needRetry(st stmt.Statement) bool { switch st.(type) { case *stmts.PreparedStmt, *stmts.ShowStmt, *stmts.DoStmt: return false default: return true } } func isPreparedStmt(st stmt.Statement) bool { switch st.(type) { case *stmts.PreparedStmt: return true default: return false } } func (s *session) Retry() error { nh := s.history.clone() defer func() { s.history.history = nh.history }() if forUpdate := s.Value(forupdate.ForUpdateKey); forUpdate != nil { return errors.Errorf("can not retry select for update statement") } var err error for { s.resetHistory() s.FinishTxn(true) success := true for _, sr := range nh.history { st := sr.st // Skip prepare statement if !needRetry(st) { continue } log.Warnf("Retry %s", st.OriginText()) _, err = runStmt(s, st) if err != nil { if terror.ErrorEqual(err, kv.ErrConditionNotMatch) { success = false break } log.Warnf("session:%v, err:%v", s, err) return errors.Trace(err) } } if success { break } } return nil } // ExecRestrictedSQL implements SQLHelper interface. // This is used for executing some restricted sql statements. func (s *session) ExecRestrictedSQL(ctx context.Context, sql string) (rset.Recordset, error) { if ctx.Value(&sqlexec.RestrictedSQLExecutorKeyType{}) != nil { // We do not support run this function concurrently. // TODO: Maybe we should remove this restriction latter. return nil, errors.New("Should not call ExecRestrictedSQL concurrently.") } statements, err := Compile(ctx, sql) if err != nil { log.Errorf("Compile %s with error: %v", sql, err) return nil, errors.Trace(err) } if len(statements) != 1 { log.Errorf("ExecRestrictedSQL only executes one statement. Too many/few statement in %s", sql) return nil, errors.New("Wrong number of statement.") } st := statements[0] // Check statement for some restriction // For example only support DML on system meta table. // TODO: Add more restrictions. log.Infof("Executing %s [%s]", st.OriginText(), sql) ctx.SetValue(&sqlexec.RestrictedSQLExecutorKeyType{}, true) defer ctx.ClearValue(&sqlexec.RestrictedSQLExecutorKeyType{}) rs, err := st.Exec(ctx) return rs, errors.Trace(err) } // GetGlobalSysVar implements RestrictedSQLExecutor.GetGlobalSysVar interface. func (s *session) GetGlobalSysVar(ctx context.Context, name string) (string, error) { sql := fmt.Sprintf(`SELECT VARIABLE_VALUE FROM %s.%s WHERE VARIABLE_NAME="%s";`, mysql.SystemDB, mysql.GlobalVariablesTable, name) rs, err := s.ExecRestrictedSQL(ctx, sql) if err != nil { return "", errors.Trace(err) } defer rs.Close() row, err := rs.Next() if err != nil { return "", errors.Trace(err) } if row == nil { return "", fmt.Errorf("Unknown sys var: %s", name) } value, err := types.ToString(row.Data[0]) if err != nil { return "", errors.Trace(err) } return value, nil } // SetGlobalSysVar implements RestrictedSQLExecutor.SetGlobalSysVar interface. func (s *session) SetGlobalSysVar(ctx context.Context, name string, value string) error { sql := fmt.Sprintf(`UPDATE %s.%s SET VARIABLE_VALUE="%s" WHERE VARIABLE_NAME="%s";`, mysql.SystemDB, mysql.GlobalVariablesTable, value, strings.ToLower(name)) _, err := s.ExecRestrictedSQL(ctx, sql) return errors.Trace(err) } // IsAutocommit checks if it is in the auto-commit mode. func (s *session) isAutocommit(ctx context.Context) bool { if ctx.Value(&sqlexec.RestrictedSQLExecutorKeyType{}) != nil { return false } autocommit, ok := variable.GetSessionVars(ctx).Systems["autocommit"] if !ok { if s.initing { return false } var err error autocommit, err = s.GetGlobalSysVar(ctx, "autocommit") if err != nil { log.Errorf("Get global sys var error: %v", err) return false } ok = true } if ok && (autocommit == "ON" || autocommit == "on" || autocommit == "1") { variable.GetSessionVars(ctx).SetStatusFlag(mysql.ServerStatusAutocommit, true) return true } variable.GetSessionVars(ctx).SetStatusFlag(mysql.ServerStatusAutocommit, false) return false } func (s *session) ShouldAutocommit(ctx context.Context) bool { if ctx.Value(&sqlexec.RestrictedSQLExecutorKeyType{}) != nil { return false } // With START TRANSACTION, autocommit remains disabled until you end // the transaction with COMMIT or ROLLBACK. if variable.GetSessionVars(ctx).Status&mysql.ServerStatusInTrans == 0 && s.isAutocommit(ctx) { return true } return false } func (s *session) Execute(sql string) ([]rset.Recordset, error) { statements, err := Compile(s, sql) if err != nil { log.Errorf("Syntax error: %s", sql) log.Errorf("Error occurs at %s.", err) return nil, errors.Trace(err) } var rs []rset.Recordset for _, st := range statements { r, err := runStmt(s, st) if err != nil { log.Warnf("session:%v, err:%v", s, err) return nil, errors.Trace(err) } // Record executed query if isPreparedStmt(st) { ps := st.(*stmts.PreparedStmt) s.history.add(ps.ID, st) } else { s.history.add(0, st) } if r != nil { rs = append(rs, r) } } return rs, nil } // For execute prepare statement in binary protocol func (s *session) PrepareStmt(sql string) (stmtID uint32, paramCount int, fields []*field.ResultField, err error) { return prepareStmt(s, sql) } // checkArgs makes sure all the arguments' types are known and can be handled. // integer types are converted to int64 and uint64, time.Time is converted to mysql.Time. // time.Duration is converted to mysql.Duration, other known types are leaved as it is. func checkArgs(args ...interface{}) error { for i, v := range args { switch x := v.(type) { case bool: if x { args[i] = int64(1) } else { args[i] = int64(0) } case int8: args[i] = int64(x) case int16: args[i] = int64(x) case int32: args[i] = int64(x) case int: args[i] = int64(x) case uint8: args[i] = uint64(x) case uint16: args[i] = uint64(x) case uint32: args[i] = uint64(x) case uint: args[i] = uint64(x) case int64: case uint64: case float32: case float64: case string: case []byte: case time.Duration: args[i] = mysql.Duration{Duration: x} case time.Time: args[i] = mysql.Time{Time: x, Type: mysql.TypeDatetime} case nil: default: return errors.Errorf("cannot use arg[%d] (type %T):unsupported type", i, v) } } return nil } // Execute a prepared statement func (s *session) ExecutePreparedStmt(stmtID uint32, args ...interface{}) (rset.Recordset, error) { err := checkArgs(args...) if err != nil { return nil, err } st := &stmts.ExecuteStmt{ID: stmtID} s.history.add(stmtID, st, args...) return runStmt(s, st, args...) } func (s *session) DropPreparedStmt(stmtID uint32) error { return dropPreparedStmt(s, stmtID) } // If forceNew is true, GetTxn() must return a new transaction. // In this situation, if current transaction is still in progress, // there will be an implicit commit and create a new transaction. func (s *session) GetTxn(forceNew bool) (kv.Transaction, error) { var err error if s.txn == nil { s.resetHistory() s.txn, err = s.store.Begin() if err != nil { return nil, err } if !s.isAutocommit(s) { variable.GetSessionVars(s).SetStatusFlag(mysql.ServerStatusInTrans, true) } log.Infof("New txn:%s in session:%d", s.txn, s.sid) return s.txn, nil } if forceNew { err = s.txn.Commit() variable.GetSessionVars(s).SetStatusFlag(mysql.ServerStatusInTrans, false) if err != nil { return nil, err } s.resetHistory() s.txn, err = s.store.Begin() if err != nil { return nil, err } if !s.isAutocommit(s) { variable.GetSessionVars(s).SetStatusFlag(mysql.ServerStatusInTrans, true) } log.Warnf("Force new txn:%s in session:%d", s.txn, s.sid) } return s.txn, nil } func (s *session) SetValue(key fmt.Stringer, value interface{}) { s.values[key] = value } func (s *session) Value(key fmt.Stringer) interface{} { value := s.values[key] return value } func (s *session) ClearValue(key fmt.Stringer) { delete(s.values, key) } // Close function does some clean work when session end. func (s *session) Close() error { return s.FinishTxn(true) } func (s *session) Auth(user string, auth []byte, salt []byte) bool { strs := strings.Split(user, "@") if len(strs) != 2 { log.Warnf("Invalid format for user: %s", user) return false } // Get user password. name := strs[0] host := strs[1] authSQL := fmt.Sprintf("SELECT Password FROM %s.%s WHERE User=\"%s\" and Host=\"%s\";", mysql.SystemDB, mysql.UserTable, name, host) rs, err := s.Execute(authSQL) if err != nil { log.Warnf("Encounter error when auth user %s. Error: %v", user, err) return false } if len(rs) == 0 { return false } row, err := rs[0].Next() if err != nil { log.Warnf("Encounter error when auth user %s. Error: %v", user, err) return false } if row == nil || len(row.Data) == 0 { return false } pwd, ok := row.Data[0].(string) if !ok { return false } hpwd, err := util.DecodePassword(pwd) if err != nil { log.Errorf("Decode password string error %v", err) return false } checkAuth := util.CalcPassword(salt, hpwd) if !bytes.Equal(auth, checkAuth) { return false } variable.GetSessionVars(s).SetCurrentUser(user) return true } // CreateSession creates a new session environment. func CreateSession(store kv.Storage) (Session, error) { s := &session{ values: make(map[fmt.Stringer]interface{}), store: store, sid: atomic.AddInt64(&sessionID, 1), } domain, err := domap.Get(store) if err != nil { return nil, err } sessionctx.BindDomain(s, domain) variable.BindSessionVars(s) variable.GetSessionVars(s).SetStatusFlag(mysql.ServerStatusAutocommit, true) // session implements autocommit.Checker. Bind it to ctx autocommit.BindAutocommitChecker(s, s) sessionMu.Lock() defer sessionMu.Unlock() _, ok := storeBootstrapped[store.UUID()] if !ok { s.initing = true bootstrap(s) s.initing = false storeBootstrapped[store.UUID()] = true } // TODO: Add auth here privChecker := &privileges.UserPrivileges{} privilege.BindPrivilegeChecker(s, privChecker) return s, nil }